//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Returns the indexing map to use for an operand of the `producer` in the
/// fused operation, given the indexing map of the result of the producer in
/// the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
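  //
  // Illustrative example (maps invented for exposition): if
  //   producerResultIndexMap = (d0, d1, d2) -> (d1, d2, d0)
  // then its inverse is (d0, d1, d2) -> (d2, d0, d1). With a producer arg map
  //   argMap = (d0, d1, d2) -> (d0, d1)
  // the composition argMap.compose(inverse) is (d0, d1, d2) -> (d2, d0), a
  // map from producer result tensor index to producer arg tensor index;
  // composing that with the consumer's map for the fused operand finally
  // re-expresses the producer arg access in terms of the fused loops.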
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

// Checks if the given operand can be dropped, and whether the remaining
// operands of the fused producer & consumer can still compute the bounds of
// the op after fusion.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
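  //
  // For example (hypothetical maps): if the remaining maps are
  //   (d0, d1) -> (d0) and (d0, d1) -> (d1),
  // their concatenation (d0, d1) -> (d0, d1) is invertible and both loop
  // bounds remain recoverable. If only (d0, d1) -> (d0) remained, d1 would be
  // unbounded and the candidate operand could not be dropped.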
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which holds the remapped values for all the
  // yield operands).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer
  // after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Fuse the operand with its producer.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the replacements for the fusion.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace
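
// Usage sketch (illustrative, not part of the upstream file): these fusion
// patterns are typically reached through the populate API declared in
// mlir/Dialect/Linalg/Transforms/Transforms.h. The always-accepting control
// function below is an assumption for the example.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();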

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `fuseWithReshapeByExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the operands to the linalg generic are now higher dimensional,
///  reshapes can be introduced to make it consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation needs to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
706 
707 /// Return the type of the operand/result to use in the expanded op given the
708 /// type in the original op.
709 static RankedTensorType getExpandedType(RankedTensorType originalType,
710  AffineMap indexingMap,
711  const ExpansionInfo &expansionInfo) {
712  SmallVector<int64_t> expandedShape;
713  for (AffineExpr expr : indexingMap.getResults()) {
714  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
715  auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
716  expandedShape.append(dimExpansion.begin(), dimExpansion.end());
717  }
718  return RankedTensorType::get(expandedShape, originalType.getElementType());
719 }

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
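/// For example (illustrative): with an indexing map (d0, d1) -> (d1, d0)
/// where d1 expands into three dimensions and d0 into two, the computed
/// reassociation is [[0, 1, 2], [3, 4]].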
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
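/// For example (illustrative): if dimension d0 is expanded into (e0, e1) with
/// a static inner extent of 4, the original `linalg.index 0` is recovered as
/// e0 * 4 + e1, materialized below as a chain of `affine.apply` ops.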
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Checks if a single dynamic dimension expanded into multiple dynamic
/// dimensions.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
                            const ExpansionInfo &expansionInfo,
                            PatternRewriter &rewriter) {
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    bool foundDynamic = false;
    for (int64_t shape : expandedShape) {
      if (!ShapedType::isDynamic(shape))
        continue;
      if (foundDynamic) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot infer expanded shape with multiple dynamic "
                      "dims in the same reassociation group");
      }
      foundDynamic = true;
    }
  }
  return success();
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          linalgOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  // TODO: With the support of multiple dynamic dims expansion in
  // tensor.expand_shape op, this case can be handled.
  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // Iterator types of the expanded op: each expanded dimension inherits the
  // iterator type of the original dimension it was created from.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is a collapsing (folding) reshape.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

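/// Pattern to fold a tensor.pad op with a producer tensor.collapse_shape op by
/// bubbling the pad through the collapse: the pad is re-created on the
/// expanded (pre-collapse) type and the collapse_shape is moved after it.
/// (Descriptive comment added for clarity; as checked below, only
/// reassociation groups of size one may carry nonzero padding.)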
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
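/// For example (illustrative): with indexingMap (d0, d1, d2) -> (d2, d0, d1)
/// and range reassociation [1, 2], the returned domain reassociation is
/// [0, 1], the positions of the dims behind results 1 and 2.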
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
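/// For example (illustrative): the sequence [0, 1] is preserved in
/// (d0, d1, d2) -> (d0, d1) and in (d0, d1, d2) -> (d2) (absent entirely),
/// but not in (d0, d1, d2) -> (d1, d0, d2), where d0 and d1 appear out of
/// order.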
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %2 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are all parallel or all reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary for reductions that are associative and commutative, but
    // using a stricter definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}
1349 
1350 /// Helper class to carry state while collapsing the `linalg.generic` op.
1351 namespace {
1352 class CollapsingInfo {
1353 public:
1354  LogicalResult initialize(unsigned origNumLoops,
1355  ArrayRef<ReassociationIndices> foldedIterationDims) {
1356  llvm::SmallDenseSet<int64_t, 4> processedDims;
1357  // Find all the dims that are folded.
1358  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1359  if (foldedIterationDim.empty())
1360  continue;
1361  // If the folded dims contain dims already folded, that's an illegal
1362  // specification. Repetition within a list is also illegal.
1363  for (auto dim : foldedIterationDim) {
1364  if (dim >= origNumLoops)
1365  return failure();
1366  if (processedDims.count(dim))
1367  return failure();
1368  processedDims.insert(dim);
1369  }
1370  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1371  foldedIterationDim.end());
1372  }
1373  if (processedDims.size() > origNumLoops)
1374  return failure();
1375 
1376  // Add all the preserved dims of the original op as single
1377  // elements to `collapsedOpToOrigOpIterationDim`.
1378  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1379  if (processedDims.count(dim))
1380  continue;
1381  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1382  }
1383 
1384  llvm::sort(collapsedOpToOrigOpIterationDim,
1385  [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1386  return lhs[0] < rhs[0];
1387  });
1388  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1389  for (const auto &foldedDims :
1390  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1391  for (const auto &dim : enumerate(foldedDims.value()))
1392  origOpToCollapsedOpIterationDim[dim.value()] =
1393  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1394  }
1395  return success();
1396  }
1397 
1398  /// Return mapping from collapsed loop domain to original loop domain.
1399  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1400  return collapsedOpToOrigOpIterationDim;
1401  }
1402 
1403  /// Return mapping from original loop domain to collapsed loop domain. The
1404  /// mapping is a pair: the first value is the dimension in the collapsed loop
1405  /// that the original loop is mapped to, and the second is the relative
1406  /// position in the folded list of this domain. For example, if the original
1407  /// loop domain is 5D and the collapsed loop domain folds all of it, i.e.
1408  ///
1409  /// ```
1410  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1411  /// ```
1412  ///
1413  /// then
1414  ///
1415  /// ```
1416  /// origOpToCollapsedOpMapping[0] = {0, 0};
1417  /// origOpToCollapsedOpMapping[1] = {0, 1};
1418  /// origOpToCollapsedOpMapping[2] = {0, 2};
1419  /// origOpToCollapsedOpMapping[3] = {1, 0};
1420  /// origOpToCollapsedOpMapping[4] = {1, 1};
1421  /// ```
1422  ///
1423  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1424  return origOpToCollapsedOpIterationDim;
1425  }
1426 
1427  /// Return the collapsed op iteration domain rank.
1428  unsigned getCollapsedOpIterationRank() const {
1429  return collapsedOpToOrigOpIterationDim.size();
1430  }
1431 
1432 private:
1433  /// Map from the iteration domain index in collapsed op to the iteration
1434  /// domain indices in the original op.
1435  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1436 
1437  /// Map from iteration domain index in the original op to the iteration domain
1438  /// index in the collapsed op.
1439  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1440 };
1441 } // namespace
1442 
1443 /// Get the iterator types for the collapsed operation given the original
1444 /// iterator types and collapsed dimensions.
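/// For example (an illustrative note): with folded dims [[0, 1], [2]] and
/// original iterator types [parallel, parallel, reduction], the collapsed op
/// gets iterator types [parallel, reduction].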
1445 static SmallVector<utils::IteratorType>
1446 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1447  const CollapsingInfo &collapsingInfo) {
1448  SmallVector<utils::IteratorType> collapsedIteratorTypes;
1449  for (ReassociationIndicesRef foldedIterDims :
1450  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1451  assert(!foldedIterDims.empty() &&
1452  "reassociation indices expected to have non-empty sets");
1453  // Just pick the iterator type of the first folded dim. Pre-condition
1454  // checks are expected to have verified that the iterator types of all
1455  // folded dimensions are the same.
1456  collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1457  }
1458  return collapsedIteratorTypes;
1459 }
1460 
1461 /// Compute the indexing map in the collapsed op that corresponds to the given
1462 /// `indexingMap` of the original operation.
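/// For example (an illustrative note): if `collapsingInfo` folds the dims as
/// [[0, 1], [2]], the original map `(d0, d1, d2) -> (d0, d1, d2)` becomes
/// `(d0, d1) -> (d0, d1)` in the collapsed op, since only the first dim of
/// each folded group contributes a result.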
1463 static AffineMap
1464 getCollapsedOpIndexingMap(AffineMap indexingMap,
1465  const CollapsingInfo &collapsingInfo) {
1466  MLIRContext *context = indexingMap.getContext();
1467  assert(indexingMap.isProjectedPermutation() &&
1468  "expected indexing map to be projected permutation");
1469  SmallVector<AffineExpr> resultExprs;
1470  auto origOpToCollapsedOpMapping =
1471  collapsingInfo.getOrigOpToCollapsedOpMapping();
1472  for (auto expr : indexingMap.getResults()) {
1473  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1474  // If the dim is not the first of its folded group, do nothing.
1475  if (origOpToCollapsedOpMapping[dim].second != 0)
1476  continue;
1477  // The next n-dims are guaranteed to be collapsed. So just use the
1478  // iteration dimension of the collapsed op.
1479  resultExprs.push_back(
1480  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1481  }
1482  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1483  resultExprs, context);
1484 }
1485 
1486 /// Return the `reassociation` indices to use to collapse the operand when the
1487 /// iteration space of a generic op is collapsed.
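/// For example (an illustrative note): for an operand with indexing map
/// `(d0, d1, d2) -> (d0, d1, d2)` and iteration dims folded as [[0, 1], [2]],
/// the operand reassociation is `[[0, 1], [2]]`, i.e. the operand's first two
/// dims are collapsed into one.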
1488 static SmallVector<ReassociationIndices>
1489 getOperandReassociation(AffineMap indexingMap,
1490  const CollapsingInfo &collapsingInfo) {
1491  unsigned counter = 0;
1492  SmallVector<ReassociationIndices> operandReassociation;
1493  auto origOpToCollapsedOpMapping =
1494  collapsingInfo.getOrigOpToCollapsedOpMapping();
1495  auto collapsedOpToOrigOpMapping =
1496  collapsingInfo.getCollapsedOpToOrigOpMapping();
1497  while (counter < indexingMap.getNumResults()) {
1498  unsigned dim =
1499  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1500  // This is the start of a group of collapsed iteration dimensions that is
1501  // guaranteed to be preserved in the indexing map. The number of folded
1502  // dims is obtained from the collapsed op to original op mapping.
1503  unsigned numFoldedDims =
1504  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1505  .size();
1506  if (origOpToCollapsedOpMapping[dim].second == 0) {
1507  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1508  operandReassociation.emplace_back(range.begin(), range.end());
1509  }
1510  counter += numFoldedDims;
1511  }
1512  return operandReassociation;
1513 }
1514 
1515 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
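/// For example (an illustrative note): a `tensor<2x4x8xf32>` operand whose
/// first two iteration dims are folded is collapsed with a
/// `tensor.collapse_shape` using reassociation [[0, 1], [2]] into a
/// `tensor<8x8xf32>`; `memref` operands get a `memref.collapse_shape` instead.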
1516 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1517  OpOperand *opOperand,
1518  const CollapsingInfo &collapsingInfo,
1519  OpBuilder &builder) {
1520  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1521  SmallVector<ReassociationIndices> operandReassociation =
1522  getOperandReassociation(indexingMap, collapsingInfo);
1523 
1524  // If the number of entries in the reassociation for the operand is the
1525  // same as the number of results of the indexing map, then there is
1526  // nothing to do for this operand.
1527  Value operand = opOperand->get();
1528  if (operandReassociation.size() == indexingMap.getNumResults())
1529  return operand;
1530 
1531  // Insert a reshape to collapse the dimensions.
1532  if (isa<MemRefType>(operand.getType())) {
1533  return builder
1534  .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1535  .getResult();
1536  }
1537  return builder
1538  .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1539  .getResult();
1540 }
1541 
1542 /// Modify the `linalg.index` operations in the original generic op to their
1543 /// values in the collapsed operation.
1544 static void generateCollapsedIndexingRegion(Location loc, Block *block,
1545  const CollapsingInfo &collapsingInfo,
1546  ValueRange loopRange,
1547  RewriterBase &rewriter) {
1548  OpBuilder::InsertionGuard g(rewriter);
1549  rewriter.setInsertionPointToStart(block);
1550 
1551  // Collect all the original index ops.
1552  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1553 
1554  // For each folded dimension list, resolve the original induction variable
1555  // values in terms of the folded dimension's induction variable, i.e.
1556  // i_{folded} = (i0 * d1 + i1) * d2 + i2
1557  // can be inverted to
1558  // i2 = i_{folded} % d2
1559  // i1 = (i_{folded} / d2) % d1
1560  // i0 = i_{folded} / (d1 * d2)
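  // For example (an illustrative note, not from the original source): with
  // d1 = 4, d2 = 5 and i0 = 2, i1 = 3, i2 = 1, we get
  // i_{folded} = (2 * 4 + 3) * 5 + 1 = 56, and indeed 56 % 5 = 1,
  // (56 / 5) % 4 = 3 and 56 / 20 = 2 recover i2, i1 and i0.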
1561  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1562  for (auto foldedDims :
1563  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1564  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1565  Value newIndexVal =
1566  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1567  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1568  indexReplacementVals[dim] =
1569  rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1570  newIndexVal =
1571  rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1572  }
1573  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1574  }
1575 
1576  for (auto indexOp : indexOps) {
1577  auto dim = indexOp.getDim();
1578  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1579  }
1580 }
1581 
1582 static void collapseOperandsAndResults(LinalgOp op,
1583  const CollapsingInfo &collapsingInfo,
1584  RewriterBase &rewriter,
1585  SmallVectorImpl<Value> &inputOperands,
1586  SmallVectorImpl<Value> &outputOperands,
1587  SmallVectorImpl<Type> &resultTypes) {
1588  Location loc = op->getLoc();
1589  inputOperands =
1590  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1591  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1592  rewriter);
1593  });
1594 
1595  // Get the output operands and result types.
1596  resultTypes.reserve(op.getNumDpsInits());
1597  outputOperands.reserve(op.getNumDpsInits());
1598  for (OpOperand &output : op.getDpsInitsMutable()) {
1599  Value newOutput =
1600  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1601  outputOperands.push_back(newOutput);
1602  // If the op has "buffer semantics", then the init operands are ranked
1603  // memrefs and the op has no results.
1604  if (!op.hasPureBufferSemantics())
1605  resultTypes.push_back(newOutput.getType());
1606  }
1607 }
1608 
1609 /// Clone a `LinalgOp` to a collapsed version of the same name.
1610 template <typename OpTy>
1611 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1612  const CollapsingInfo &collapsingInfo) {
1613  return nullptr;
1614 }
1615 
1616 /// Collapse any `LinalgOp` that does not require any specialization such as
1617 /// indexing_maps, iterator_types, etc.
1618 template <>
1619 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1620  const CollapsingInfo &collapsingInfo) {
1621  SmallVector<Value> inputOperands, outputOperands;
1622  SmallVector<Type> resultTypes;
1623  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1624  outputOperands, resultTypes);
1625 
1626  return clone(
1627  rewriter, origOp, resultTypes,
1628  llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1629 }
1630 
1631 /// Collapse a `GenericOp`
1632 template <>
1633 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1634  GenericOp origOp,
1635  const CollapsingInfo &collapsingInfo) {
1636  SmallVector<Value> inputOperands, outputOperands;
1637  SmallVector<Type> resultTypes;
1638  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1639  outputOperands, resultTypes);
1640  SmallVector<AffineMap> indexingMaps(
1641  llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1642  return getCollapsedOpIndexingMap(map, collapsingInfo);
1643  }));
1644 
1645  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1646  origOp.getIteratorTypesArray(), collapsingInfo));
1647 
1648  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1649  origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1650  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1651  Block *origOpBlock = &origOp->getRegion(0).front();
1652  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1653  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1654  collapsedOpBlock->getArguments());
1655  return collapsedOp;
1656 }
1657 
1658 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1659  RewriterBase &rewriter) {
1660  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1661  return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1662  } else {
1663  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1664  }
1665 }
1666 
1667 /// Implementation of fusion with a reshape operation by collapsing dimensions.
1668 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1669  LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1670  RewriterBase &rewriter) {
1671  // Bail on trivial no-op cases.
1672  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1673  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1674  return foldedDims.size() <= 1;
1675  }))
1676  return failure();
1677 
1678  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1679  if (hasPureBufferSemantics &&
1680  !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1681  MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1682  if (!memRefToCollapse)
1683  return true;
1684 
1685  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1686  memRefToCollapse, foldedIterationDims);
1687  }))
1688  return rewriter.notifyMatchFailure(op,
1689  "memref is not guaranteed collapsible");
1690 
1691  CollapsingInfo collapsingInfo;
1692  if (failed(
1693  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1694  return rewriter.notifyMatchFailure(
1695  op, "illegal to collapse specified dimensions");
1696  }
1697 
1698  // Bail on non-canonical ranges.
1699  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1700  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1701  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1702  return cast<IntegerAttr>(attr).getInt() == value;
1703  llvm::APInt actual;
1704  return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
1705  actual.getSExtValue() == value;
1706  };
1707  if (!llvm::all_of(loopRanges, [&](Range range) {
1708  return opFoldIsConstantValue(range.offset, 0) &&
1709  opFoldIsConstantValue(range.stride, 1);
1710  })) {
1711  return rewriter.notifyMatchFailure(
1712  op, "expected all loop ranges to have zero start and unit stride");
1713  }
1714 
1715  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1716 
1717  Location loc = op->getLoc();
1718  if (collapsedOp.hasIndexSemantics()) {
1719  // Collect the loop range of the generic op.
1720  OpBuilder::InsertionGuard g(rewriter);
1721  rewriter.setInsertionPoint(collapsedOp);
1722  SmallVector<Value> loopBound =
1723  llvm::map_to_vector(loopRanges, [&](Range range) {
1724  return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1725  });
1726  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1727  collapsingInfo, loopBound, rewriter);
1728  }
1729 
1730  // Insert expanding reshape for the result to get back the original result
1731  // type.
1732  SmallVector<Value> results;
1733  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1734  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1735  auto originalResultType =
1736  cast<ShapedType>(originalResult.value().getType());
1737  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1738  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1739  AffineMap indexingMap =
1740  op.getIndexingMapMatchingResult(originalResult.value());
1741  SmallVector<ReassociationIndices> reassociation =
1742  getOperandReassociation(indexingMap, collapsingInfo);
1743  Value result;
1744  if (isa<MemRefType>(collapsedOpResult.getType())) {
1745  MemRefType expandShapeResultType = MemRefType::get(
1746  originalResultType.getShape(), originalResultType.getElementType());
1747  result = rewriter.create<memref::ExpandShapeOp>(
1748  loc, expandShapeResultType, collapsedOpResult, reassociation);
1749  } else {
1750  result = rewriter.create<tensor::ExpandShapeOp>(
1751  loc, originalResultType, collapsedOpResult, reassociation);
1752  }
1753  results.push_back(result);
1754  } else {
1755  results.push_back(collapsedOpResult);
1756  }
1757  }
1758  return CollapseResult{results, collapsedOp};
1759 }
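// A minimal usage sketch (illustrative, not part of the original source):
// assuming `op` has a 4D iteration space, collapsing it to 2D could look like
//
//   SmallVector<ReassociationIndices> folded = {{0, 1}, {2, 3}};
//   FailureOr<CollapseResult> collapsed =
//       collapseOpIterationDims(op, folded, rewriter);
//   if (succeeded(collapsed))
//     rewriter.replaceOp(op, collapsed->results);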
1760 
1761 namespace {
1762 
1763 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1764 /// contracting dimensions of its iteration space.
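/// For example (an illustrative sketch, syntax abbreviated):
///
/// ```mlir
/// %e = tensor.expand_shape %0 [[0, 1], [2]]
///     : tensor<8x4xf32> into tensor<2x4x4xf32>
/// %r = linalg.generic ... ins(%e : tensor<2x4x4xf32>) ...
/// ```
///
/// can be rewritten so the generic operates on %0 directly, with its
/// iteration dims (0, 1) collapsed and an expanding reshape re-introduced on
/// the result.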
1765 class FoldWithProducerReshapeOpByCollapsing
1766  : public OpRewritePattern<GenericOp> {
1767 public:
1768  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1769  ControlFusionFn foldReshapes,
1770  PatternBenefit benefit = 1)
1771  : OpRewritePattern<GenericOp>(context, benefit),
1772  controlFoldingReshapes(std::move(foldReshapes)) {}
1773 
1774  LogicalResult matchAndRewrite(GenericOp genericOp,
1775  PatternRewriter &rewriter) const override {
1776  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1777  tensor::ExpandShapeOp reshapeOp =
1778  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1779  if (!reshapeOp)
1780  continue;
1781 
1782  SmallVector<ReassociationIndices> collapsableIterationDims =
1783  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1784  reshapeOp.getReassociationIndices());
1785  if (collapsableIterationDims.empty() ||
1786  !controlFoldingReshapes(&opOperand)) {
1787  continue;
1788  }
1789 
1790  std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1791  genericOp, collapsableIterationDims, rewriter);
1792  if (!collapseResult) {
1793  return rewriter.notifyMatchFailure(
1794  genericOp, "failed to do the fusion by collapsing transformation");
1795  }
1796 
1797  rewriter.replaceOp(genericOp, collapseResult->results);
1798  return success();
1799  }
1800  return failure();
1801  }
1802 
1803 private:
1804  ControlFusionFn controlFoldingReshapes;
1805 };
1806 
1807 class FoldPadWithProducerReshapeOpByCollapsing
1808  : public OpRewritePattern<tensor::PadOp> {
1809 public:
1810  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1811  ControlFusionFn foldReshapes,
1812  PatternBenefit benefit = 1)
1813  : OpRewritePattern<tensor::PadOp>(context, benefit),
1814  controlFoldingReshapes(std::move(foldReshapes)) {}
1815 
1816  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1817  PatternRewriter &rewriter) const override {
1818  tensor::ExpandShapeOp reshapeOp =
1819  padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1820  if (!reshapeOp)
1821  return failure();
1822  if (!reshapeOp->hasOneUse())
1823  return failure();
1824 
1825  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1826  return rewriter.notifyMatchFailure(padOp,
1827  "fusion blocked by control function");
1828  }
1829 
1830  ArrayRef<int64_t> low = padOp.getStaticLow();
1831  ArrayRef<int64_t> high = padOp.getStaticHigh();
1832  SmallVector<ReassociationIndices> reassociations =
1833  reshapeOp.getReassociationIndices();
1834 
1835  for (auto reInd : reassociations) {
1836  if (reInd.size() == 1)
1837  continue;
1838  if (llvm::any_of(reInd, [&](int64_t ind) {
1839  return low[ind] != 0 || high[ind] != 0;
1840  })) {
1841  return failure();
1842  }
1843  }
1844 
1845  SmallVector<OpFoldResult> newLow, newHigh;
1846  RankedTensorType collapsedType = reshapeOp.getSrcType();
1847  RankedTensorType paddedType = padOp.getResultType();
1848  SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1849  SmallVector<OpFoldResult> expandedPaddedSizes(
1850  getMixedValues(reshapeOp.getStaticOutputShape(),
1851  reshapeOp.getOutputShape(), rewriter));
1852  AffineExpr d0, d1, d2;
1853  bindDims(rewriter.getContext(), d0, d1, d2);
1854  auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1855  Location loc = reshapeOp->getLoc();
1856  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1857  OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1858  OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1859  if (reInd.size() == 1) {
1860  collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1861  OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
1862  rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1863  expandedPaddedSizes[reInd[0]] = paddedSize;
1864  }
1865  newLow.push_back(l);
1866  newHigh.push_back(h);
1867  }
1868 
1869  RankedTensorType collapsedPaddedType =
1870  paddedType.clone(collapsedPaddedShape);
1871  auto newPadOp = rewriter.create<tensor::PadOp>(
1872  loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1873  padOp.getConstantPaddingValue(), padOp.getNofold());
1874 
1875  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1876  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1877  expandedPaddedSizes);
1878 
1879  return success();
1880  }
1881 
1882 private:
1883  ControlFusionFn controlFoldingReshapes;
1884 };
1885 
1886 /// Pattern to collapse dimensions.
1887 template <typename LinalgType>
1888 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1889 public:
1890  CollapseLinalgDimensions(MLIRContext *context,
1891  GetCollapsableDimensionsFn collapseDimensions,
1892  PatternBenefit benefit = 1)
1893  : OpRewritePattern<LinalgType>(context, benefit),
1894  controlCollapseDimension(std::move(collapseDimensions)) {}
1895 
1896  LogicalResult matchAndRewrite(LinalgType op,
1897  PatternRewriter &rewriter) const override {
1898  SmallVector<ReassociationIndices> collapsableIterationDims =
1899  controlCollapseDimension(op);
1900  if (collapsableIterationDims.empty())
1901  return failure();
1902 
1903  // Check if the specified list of dimensions to collapse is a valid list.
1904  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
1905  collapsableIterationDims)) {
1906  return rewriter.notifyMatchFailure(
1907  op, "specified dimensions cannot be collapsed");
1908  }
1909 
1910  std::optional<CollapseResult> collapseResult =
1911  collapseOpIterationDims(op, collapsableIterationDims, rewriter);
1912  if (!collapseResult) {
1913  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
1914  }
1915  rewriter.replaceOp(op, collapseResult->results);
1916  return success();
1917  }
1918 
1919 private:
1920  GetCollapsableDimensionsFn controlCollapseDimension;
1921 };
1922 
1923 } // namespace
1924 
1925 //===---------------------------------------------------------------------===//
1926 // Methods and patterns that fuse constants with linalg.generic operations.
1927 //===---------------------------------------------------------------------===//
1928 
1929 namespace {
1930 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1931 /// handle cases where the constant is not single-valued.
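/// For example (an illustrative note): an input produced by
/// `arith.constant dense<4.0> : tensor<8xf32>` is dropped from the operand
/// list and replaced inside the region by a scalar
/// `arith.constant 4.0 : f32`.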
1932 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1933 public:
1934  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1935  : OpRewritePattern<GenericOp>(context, benefit) {}
1936 
1937  LogicalResult matchAndRewrite(GenericOp genericOp,
1938  PatternRewriter &rewriter) const override {
1939  if (!genericOp.hasPureTensorSemantics())
1940  return failure();
1941  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1942  Operation *def = opOperand->get().getDefiningOp();
1943  TypedAttr constantAttr;
1944  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1945  {
1946  DenseElementsAttr splatAttr;
1947  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1948  splatAttr.isSplat() &&
1949  splatAttr.getType().getElementType().isIntOrFloat()) {
1950  constantAttr = splatAttr.getSplatValue<TypedAttr>();
1951  return true;
1952  }
1953  }
1954  {
1955  IntegerAttr intAttr;
1956  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1957  constantAttr = intAttr;
1958  return true;
1959  }
1960  }
1961  {
1962  FloatAttr floatAttr;
1963  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1964  constantAttr = floatAttr;
1965  return true;
1966  }
1967  }
1968  return false;
1969  };
1970 
1971  auto resultValue = dyn_cast<OpResult>(opOperand->get());
1972  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1973  continue;
1974 
1975  // The operands and the indexing_maps of the fused operation are the same
1976  // as the operands and indexing_maps of the generic operation, with the
1977  // values at the constant index dropped.
1978  SmallVector<AffineMap> fusedIndexMaps;
1979  SmallVector<Value> fusedOperands;
1980  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1981  fusedIndexMaps.reserve(genericOp->getNumOperands());
1982  fusedOperands.reserve(genericOp.getNumDpsInputs());
1983  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1984  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1985  if (inputOperand == opOperand)
1986  continue;
1987  Value inputValue = inputOperand->get();
1988  fusedIndexMaps.push_back(
1989  genericOp.getMatchingIndexingMap(inputOperand));
1990  fusedOperands.push_back(inputValue);
1991  fusedLocs.push_back(inputValue.getLoc());
1992  }
1993  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1994  fusedIndexMaps.push_back(
1995  genericOp.getMatchingIndexingMap(&outputOperand));
1996 
1997  // Check if the operation shapes to loops map is computable.
1998  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1999  return rewriter.notifyMatchFailure(
2000  genericOp, "fused op loop bound computation failed");
2001  }
2002 
2003  // Create a constant scalar value from the splat constant.
2004  Value scalarConstant =
2005  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
2006 
2007  SmallVector<Value> outputOperands = genericOp.getOutputs();
2008  auto fusedOp = rewriter.create<GenericOp>(
2009  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2010  /*inputs=*/fusedOperands,
2011  /*outputs=*/outputOperands,
2012  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2013  genericOp.getIteratorTypes(),
2014  /*doc=*/nullptr,
2015  /*library_call=*/nullptr);
2016 
2017  // Map the block argument corresponding to the replaced argument with the
2018  // scalar constant.
2019  Region &region = genericOp->getRegion(0);
2020  Block &entryBlock = *region.begin();
2021  IRMapping mapping;
2022  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2023  scalarConstant);
2024  Region &fusedRegion = fusedOp->getRegion(0);
2025  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2026  mapping);
2027  rewriter.replaceOp(genericOp, fusedOp->getResults());
2028  return success();
2029  }
2030  return failure();
2031  }
2032 };
2033 
2034 } // namespace
2035 
2036 //===---------------------------------------------------------------------===//
2037 // Miscellaneous patterns that help fusion.
2038 //===---------------------------------------------------------------------===//
2039 
2040 namespace {
2041 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2042 /// value of the `outs` operand is not used within the op. This is only
2043 /// implemented for `linalg.generic` operations for now, but should hold for all
2044 /// linalg structured ops.
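/// For example (an illustrative note): if the `outs` tensor of a
/// `linalg.generic` is only written, it is replaced with a fresh
/// `tensor.empty` of the same shape, removing a false dependency on the
/// producer of the original `outs` value and unblocking further fusion.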
2045 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2046  using OpRewritePattern<GenericOp>::OpRewritePattern;
2047 
2048  LogicalResult matchAndRewrite(GenericOp op,
2049  PatternRewriter &rewriter) const override {
2050  rewriter.startOpModification(op);
2051  bool modifiedOutput = false;
2052  Location loc = op.getLoc();
2053  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2054  if (!op.payloadUsesValueFromOperand(&opOperand)) {
2055  Value operandVal = opOperand.get();
2056  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2057  if (!operandType)
2058  continue;
2059 
2060  // If outs is sparse, leave it to the sparsifier.
2061  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2062  continue;
2063 
2064  // If outs is already an `empty` operation, nothing to do.
2065  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2066  if (definingOp)
2067  continue;
2068  modifiedOutput = true;
2069  SmallVector<OpFoldResult> mixedSizes =
2070  tensor::getMixedSizes(rewriter, loc, operandVal);
2071  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
2072  loc, mixedSizes, operandType.getElementType());
2073  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2074  }
2075  }
2076  if (!modifiedOutput) {
2077  rewriter.cancelOpModification(op);
2078  return failure();
2079  }
2080  rewriter.finalizeOpModification(op);
2081  return success();
2082  }
2083 };
2084 
2085 /// Fold linalg.fill into linalg.generic
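/// For example (an illustrative note): if an input of a `linalg.generic` is
/// produced by `linalg.fill ins(%cst : f32) outs(...)`, the corresponding
/// block argument is replaced by %cst (converted to the element type if
/// needed), so the fill result is no longer read by the payload.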
2086 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2087  using OpRewritePattern<GenericOp>::OpRewritePattern;
2088 
2089  LogicalResult matchAndRewrite(GenericOp genericOp,
2090  PatternRewriter &rewriter) const override {
2091  if (!genericOp.hasPureTensorSemantics())
2092  return failure();
2093  bool fillFound = false;
2094  Block &payload = genericOp.getRegion().front();
2095  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2096  if (!genericOp.payloadUsesValueFromOperand(opOperand))
2097  continue;
2098  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2099  if (!fillOp)
2100  continue;
2101  fillFound = true;
2102  Value fillVal = fillOp.value();
2103  auto resultType =
2104  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2105  Value convertedVal =
2106  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2107  /*isUnsignedCast =*/false);
2108  rewriter.replaceAllUsesWith(
2109  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2110  }
2111  return success(fillFound);
2112  }
2113 };
2114 } // namespace
2115 
2116 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2117  RewritePatternSet &patterns,
2118  const ControlFusionFn &controlFoldingReshapes) {
2119  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2120  controlFoldingReshapes);
2121  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2122  controlFoldingReshapes);
2123  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2124  controlFoldingReshapes);
2125 }
2126 
2127 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2128  RewritePatternSet &patterns,
2129  const ControlFusionFn &controlFoldingReshapes) {
2130  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2131  controlFoldingReshapes);
2132  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2133  patterns.getContext(), controlFoldingReshapes);
2134 }
2135 
2136 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2137  RewritePatternSet &patterns,
2138  const ControlFusionFn &controlElementwiseOpsFusion) {
2139  auto *context = patterns.getContext();
2140  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2141  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2142  RemoveOutsDependency>(context);
2143  // Add the patterns that clean up dead operands and results.
2144  populateEraseUnusedOperandsAndResultsPatterns(patterns);
2145 }
2146 
2147 void mlir::linalg::populateCollapseDimensions(
2148  RewritePatternSet &patterns,
2149  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2150  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2151  CollapseLinalgDimensions<linalg::CopyOp>>(
2152  patterns.getContext(), controlCollapseDimensions);
2153 }
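// A minimal usage sketch (illustrative, not part of the original source):
// clients typically populate a subset of these patterns in their own pass
// and apply them greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));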
2154 
2155 //===---------------------------------------------------------------------===//
2156 // Passes
2157 //===---------------------------------------------------------------------===//
2158 
2159 namespace {
2160 
2161 /// Pass that fuses generic ops on tensors. Used only for testing.
2162 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2163 // patterns added here heavily depends on the cost function used. Having an
2164 // opinionated pass of this form is not recommended. Deprecate this pass in
2165 // favor of test passes that check the functionality of each of the patterns
2166 // added here individually.
2167 struct LinalgElementwiseOpFusionPass
2168  : public impl::LinalgElementwiseOpFusionPassBase<
2169  LinalgElementwiseOpFusionPass> {
2170  using impl::LinalgElementwiseOpFusionPassBase<
2171  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2172  void runOnOperation() override {
2173  Operation *op = getOperation();
2174  MLIRContext *context = op->getContext();
2175  RewritePatternSet patterns(context);
2176 
2177  // Add folding with reshape by expansion patterns.
2178  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2179  Operation *producer = fusedOperand->get().getDefiningOp();
2180  return producer && producer->hasOneUse();
2181  };
2182 
2183  // Add elementwise op fusion patterns.
2184  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2185  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2186  tensor::populateBubbleUpExpandShapePatterns(patterns);
2187 
2188  // General canonicalization patterns.
2189  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2190  GenericOp::getCanonicalizationPatterns(patterns, context);
2191  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2192  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2193  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2194  patterns);
2195 
2196  // Add constant folding patterns.
2197  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2198 
2199  // Use TopDownTraversal for compile-time reasons.
2200  GreedyRewriteConfig grc;
2201  grc.useTopDownTraversal = true;
2202  (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
2203  }
2204 };
2205 
2206 } // namespace