//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the result of the producer in
/// the consumer.
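/// For illustration (hypothetical maps, not taken from a test): assume the
/// producer result map is `(d0, d1) -> (d1, d0)` (producer loops -> result
/// tensor), the producer arg map is the identity `(d0, d1) -> (d0, d1)`, and
/// the consumer accesses the fused operand with the identity map
/// `(c0, c1) -> (c0, c1)`. Inverting the result map gives
/// `(t0, t1) -> (t1, t0)`; composing the arg map with that inverse and then
/// with the consumer map yields `(c0, c1) -> (c1, c0)`, i.e. the map from the
/// fused (consumer) loops to the producer argument's tensor indices.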
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

// Checks if the given operand can be dropped, and whether the remaining
// operands of the fused producer & consumer after the fusion can still compute
// the bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * Note: the actual fusion transformation may not agree with the result of
///   this method; this function gives a prediction based on an optimized
///   fusion.
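/// For illustration (hypothetical example, not derived from a test): if the
/// producer has two results, the first yielded value is also read in its own
/// body through the corresponding init block argument, and the second result
/// is used only by the consumer, then only index 0 is predicted to be
/// preserved; result 1 can be dropped after fusion.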
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
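/// For orientation (an informal summary of the numbered steps below): the
/// fused block's arguments are assembled as, in order, the consumer's input
/// arguments before the fused operand, all of the producer's input arguments,
/// the consumer's remaining input arguments after the fused operand, the
/// producer's preserved output arguments, and finally the consumer's output
/// arguments. The consumer's argument for the fused operand itself is replaced
/// by the producer's corresponding yielded value.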
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside the producer.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which consists of the remapped values for
  // all the yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer,
  // after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Pattern to fuse a generic op with the producer of one of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      // Remember the producer of the operand before it is fused away.
      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
/// Consider
///
///   %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                           affine_map<(d0, d1, d2) -> (d1, d2)>,
///                           affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///   %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map helps compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
/// substituting this, the structured op can be rewritten as
///
///   %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the operands to the linalg generic are now of higher dimensionality,
/// reshapes can be introduced to make the types consistent
///
///   %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///   %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///        : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
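/// For example (illustrative): with an original type `tensor<?x4xf32>`, an
/// indexing map `(d0, d1) -> (d0, d1)`, and an expansion that splits d1 into
/// two dimensions of extents [2, 2], the expanded type is `tensor<?x2x2xf32>`.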
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
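/// For example (illustrative): for an indexing map `(d0, d1) -> (d1, d0)`
/// where d1 expands into two dimensions and d0 into three, the operand's
/// reassociation is `[[0, 1], [2, 3, 4]]`: the first result dim of the map
/// covers the first two expanded dims, the second covers the next three.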
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
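/// For illustration (hypothetical shapes): if an original dimension of extent
/// 12 is expanded into two dimensions of extents [3, 4], the original index i
/// is recovered from the expanded indices (i0, i1) as i = i0 * 4 + i1, which
/// is what the `affine.apply` chain built below computes.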
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Checks if a single dynamic dimension expanded into multiple dynamic
/// dimensions.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
                            const ExpansionInfo &expansionInfo,
                            PatternRewriter &rewriter) {
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    bool foundDynamic = false;
    for (int64_t shape : expandedShape) {
      if (!ShapedType::isDynamic(shape))
        continue;
      if (foundDynamic) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot infer expanded shape with multiple dynamic "
                      "dims in the same reassociation group");
      }
      foundDynamic = true;
    }
  }
  return success();
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          linalgOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  // TODO: With the support of multiple dynamic dims expansion in
  // tensor.expand_shape op, this case can be handled.
  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.pad op with a producing tensor.collapse_shape op
/// by padding the source of the collapse_shape (in its expanded form) and
/// collapsing the padded result back to the original padded type.
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
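/// For example (illustrative): with `indexingMap = (d0, d1, d2) -> (d1, d0,
/// d2)` and range reassociation [1, 2] (folding the map's last two results),
/// the corresponding domain indices are [0, 2].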
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
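/// For example (illustrative): the sequence [1, 2] is preserved in
/// `(d0, d1, d2) -> (d0, d1, d2)` and in `(d0, d1, d2) -> (d1, d2)` (and,
/// vacuously, in `(d0, d1, d2) -> (d0)` where it does not appear), but not in
/// `(d0, d1, d2) -> (d2, d1)` or `(d0, d1, d2) -> (d1, d0, d2)`.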
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that all folded iterator types are all parallel or all reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, the dims are trivially not contiguous.
        // This condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}
1355 
1356 /// Helper class to carry state while collapsing the `linalg.generic` op.
1357 namespace {
1358 class CollapsingInfo {
1359 public:
1360  LogicalResult initialize(unsigned origNumLoops,
1361  ArrayRef<ReassociationIndices> foldedIterationDims) {
1362  llvm::SmallDenseSet<int64_t, 4> processedDims;
1363  // Find all the dims that are folded.
1364  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1365  if (foldedIterationDim.empty())
1366  continue;
 1367  // If the folded dims contain dims that were already folded, that is an
 1368  // illegal specification. Repetition within a list is also illegal.
1369  for (auto dim : foldedIterationDim) {
1370  if (dim >= origNumLoops)
1371  return failure();
1372  if (processedDims.count(dim))
1373  return failure();
1374  processedDims.insert(dim);
1375  }
1376  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1377  foldedIterationDim.end());
1378  }
1379  if (processedDims.size() > origNumLoops)
1380  return failure();
1381 
1382  // Add all the preserved dims of the original op as single
1383  // elements to `collapsedOpToOrigOpIterationDim`.
1384  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1385  if (processedDims.count(dim))
1386  continue;
1387  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1388  }
1389 
 1390  llvm::sort(collapsedOpToOrigOpIterationDim,
 1391  [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
 1392  return lhs[0] < rhs[0];
1393  });
1394  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1395  for (const auto &foldedDims :
1396  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1397  for (const auto &dim : enumerate(foldedDims.value()))
1398  origOpToCollapsedOpIterationDim[dim.value()] =
1399  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1400  }
1401  return success();
1402  }
1403 
1404  /// Return mapping from collapsed loop domain to original loop domain.
1405  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1406  return collapsedOpToOrigOpIterationDim;
1407  }
1408 
1409  /// Return mapping from original loop domain to collapsed loop domain. The
1410  /// mapping is a pair. First value is the dimension in the collapsed loop that
1411  /// the original loop is mapped to. Second is the relative position in folded
 1412  /// list of this domain. For example, if the original loop domain is 5D and
 1413  /// is folded into two collapsed dimensions, i.e.
 1414  ///
 1415  /// ```
 1416  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1417  /// ```
1418  ///
1419  /// then
1420  ///
1421  /// ```
1422  /// origOpToCollapsedOpMapping[0] = {0, 0};
1423  /// origOpToCollapsedOpMapping[1] = {0, 1};
1424  /// origOpToCollapsedOpMapping[2] = {0, 2};
1425  /// origOpToCollapsedOpMapping[3] = {1, 0};
1426  /// origOpToCollapsedOpMapping[4] = {1, 1};
1427  /// ```
1428  ///
1429  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1430  return origOpToCollapsedOpIterationDim;
1431  }
1432 
1433  /// Return the collapsed op iteration domain rank.
1434  unsigned getCollapsedOpIterationRank() const {
1435  return collapsedOpToOrigOpIterationDim.size();
1436  }
1437 
1438 private:
1439  /// Map from the iteration domain index in collapsed op to the iteration
1440  /// domain indices in the original op.
1441  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1442 
1443  /// Map from iteration domain index in the original op to the iteration domain
1444  /// index in the collapsed op.
1445  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1446 };
1447 } // namespace
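// A hedged sketch of the expected behavior of CollapsingInfo (values derived
// from the logic above, not taken from the original source), assuming a
// 5-loop op and foldedIterationDims = {{1, 2}}:
//
//   CollapsingInfo info;
//   (void)info.initialize(/*origNumLoops=*/5, {{1, 2}});
//   // getCollapsedOpToOrigOpMapping() == [[0], [1, 2], [3], [4]]
//   // getOrigOpToCollapsedOpMapping() == [{0, 0}, {1, 0}, {1, 1},
//   //                                     {2, 0}, {3, 0}]
//   // getCollapsedOpIterationRank()   == 4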
1448 
1449 /// Get the iterator types for the collapsed operation given the original
1450 /// iterator types and collapsed dimensions.
 1451 static SmallVector<utils::IteratorType>
 1452 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
 1453  const CollapsingInfo &collapsingInfo) {
1454  SmallVector<utils::IteratorType> collapsedIteratorTypes;
1455  for (ReassociationIndicesRef foldedIterDims :
1456  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1457  assert(!foldedIterDims.empty() &&
1458  "reassociation indices expected to have non-empty sets");
 1459  // Just pick the iterator type of the first folded dim. Pre-condition
 1460  // checks are expected to have verified that the iterator types of all
 1461  // folded dimensions are the same.
1462  collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1463  }
1464  return collapsedIteratorTypes;
1465 }
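// For instance (illustrative, not from the original source): with original
// iterator types [parallel, parallel, reduction] and collapsed groups
// [[0, 1], [2]], the collapsed iterator types are [parallel, reduction].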
1466 
1467 /// Compute the indexing map in the collapsed op that corresponds to the given
1468 /// `indexingMap` of the original operation.
 1469 static AffineMap
 1470 getCollapsedOpIndexingMap(AffineMap indexingMap,
 1471  const CollapsingInfo &collapsingInfo) {
1472  MLIRContext *context = indexingMap.getContext();
1473  assert(indexingMap.isProjectedPermutation() &&
1474  "expected indexing map to be projected permutation");
1475  SmallVector<AffineExpr> resultExprs;
1476  auto origOpToCollapsedOpMapping =
1477  collapsingInfo.getOrigOpToCollapsedOpMapping();
1478  for (auto expr : indexingMap.getResults()) {
1479  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
 1480  // If the dim is not the first of its collapsed group, do nothing.
1481  if (origOpToCollapsedOpMapping[dim].second != 0)
1482  continue;
 1483  // The next n dims are guaranteed to be collapsed, so just use the
 1484  // iteration dimension of the collapsed op.
1485  resultExprs.push_back(
1486  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1487  }
1488  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1489  resultExprs, context);
1490 }
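// Worked example (a sketch, assuming the collapsed groups [[0, 1], [2]]): the
// original indexing map (d0, d1, d2) -> (d2, d0, d1) becomes
// (d0, d1) -> (d1, d0). d2 is alone in its group, so it maps to collapsed
// dim 1; d0 is the first dim of group [0, 1] and maps to collapsed dim 0;
// d1 is skipped because it is not the first dim of its group.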
1491 
1492 /// Return the `reassociation` indices to use to collapse the operand when the
1493 /// iteration space of a generic op is collapsed.
 1494 static SmallVector<ReassociationIndices>
 1495 getOperandReassociation(AffineMap indexingMap,
 1496  const CollapsingInfo &collapsingInfo) {
1497  unsigned counter = 0;
1498  SmallVector<ReassociationIndices> operandReassociation;
1499  auto origOpToCollapsedOpMapping =
1500  collapsingInfo.getOrigOpToCollapsedOpMapping();
1501  auto collapsedOpToOrigOpMapping =
1502  collapsingInfo.getCollapsedOpToOrigOpMapping();
1503  while (counter < indexingMap.getNumResults()) {
1504  unsigned dim =
1505  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
 1506  // This is the start of a group of collapsed dimensions of the iteration
 1507  // space that is guaranteed to be preserved in the indexing map. The number
 1508  // of folded dims is obtained from the collapsed op to original op mapping.
1509  unsigned numFoldedDims =
1510  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1511  .size();
1512  if (origOpToCollapsedOpMapping[dim].second == 0) {
1513  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1514  operandReassociation.emplace_back(range.begin(), range.end());
1515  }
1516  counter += numFoldedDims;
1517  }
1518  return operandReassociation;
1519 }
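// Continuing the sketch above: for the operand indexing map
// (d0, d1, d2) -> (d2, d0, d1) and collapsed groups [[0, 1], [2]], the
// operand reassociation is [[0], [1, 2]]: result 0 (d2) stays a singleton,
// and results 1 and 2 (d0, d1) fold into one operand dimension.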
1520 
1521 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1522 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1523  OpOperand *opOperand,
1524  const CollapsingInfo &collapsingInfo,
1525  OpBuilder &builder) {
1526  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1527  SmallVector<ReassociationIndices> operandReassociation =
1528  getOperandReassociation(indexingMap, collapsingInfo);
1529 
 1530  // If the number of entries in the reassociation for the operand is the
 1531  // same as the number of results of the indexing map, then there is nothing
 1532  // to do for this operand.
1533  Value operand = opOperand->get();
1534  if (operandReassociation.size() == indexingMap.getNumResults())
1535  return operand;
1536 
1537  // Insert a reshape to collapse the dimensions.
1538  if (isa<MemRefType>(operand.getType())) {
1539  return builder
1540  .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1541  .getResult();
1542  }
1543  return builder
1544  .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1545  .getResult();
1546 }
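// For instance (shapes assumed for illustration): collapsing an operand of
// type tensor<4x8x16xf32> with operand reassociation [[0], [1, 2]] would
// insert
//   tensor.collapse_shape %operand [[0], [1, 2]]
//       : tensor<4x8x16xf32> into tensor<4x128xf32>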
1547 
 1548 /// Modify the `linalg.index` operations in the original generic op to their
 1549 /// values in the collapsed operation.
 1550 void generateCollapsedIndexingRegion(Location loc, Block *block,
 1551  const CollapsingInfo &collapsingInfo,
1552  ValueRange loopRange,
1553  RewriterBase &rewriter) {
1554  OpBuilder::InsertionGuard g(rewriter);
1555  rewriter.setInsertionPointToStart(block);
1556 
1557  // Collect all the original index ops.
1558  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1559 
 1560  // For each folded dimension list, resolve the original induction variable
 1561  // values in terms of the folded dimension's induction variable:
 1562  //   i_{folded} = (i0 * d1 + i1) * d2 + i2
 1563  // can be inverted to
 1564  //   i2 = i_{folded} % d2
 1565  //   i1 = (i_{folded} / d2) % d1
 1566  //   i0 = i_{folded} / (d1 * d2)
1567  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1568  for (auto foldedDims :
1569  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1570  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1571  Value newIndexVal =
1572  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1573  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1574  indexReplacementVals[dim] =
1575  rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1576  newIndexVal =
1577  rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1578  }
1579  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1580  }
1581 
1582  for (auto indexOp : indexOps) {
1583  auto dim = indexOp.getDim();
1584  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1585  }
1586 }
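// Numeric sanity check (illustrative, with assumed sizes d1 = 3, d2 = 5): for
// original indices (i0, i1, i2) = (2, 1, 4), the folded index is
// (2 * 3 + 1) * 5 + 4 = 39, and indeed 39 % 5 == 4, (39 / 5) % 3 == 1, and
// 39 / (3 * 5) == 2.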
1587 
1589  const CollapsingInfo &collapsingInfo,
1590  RewriterBase &rewriter,
1591  SmallVectorImpl<Value> &inputOperands,
1592  SmallVectorImpl<Value> &outputOperands,
1593  SmallVectorImpl<Type> &resultTypes) {
1594  Location loc = op->getLoc();
1595  inputOperands =
1596  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1597  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1598  rewriter);
1599  });
1600 
1601  // Get the output operands and result types.
1602  resultTypes.reserve(op.getNumDpsInits());
1603  outputOperands.reserve(op.getNumDpsInits());
1604  for (OpOperand &output : op.getDpsInitsMutable()) {
1605  Value newOutput =
1606  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1607  outputOperands.push_back(newOutput);
1608  // If the op has "buffer semantics", then the init operands are ranked
1609  // memrefs and the op has no results.
1610  if (!op.hasPureBufferSemantics())
1611  resultTypes.push_back(newOutput.getType());
1612  }
1613 }
1614 
 1615 /// Clone a `LinalgOp` to a collapsed version of the same name.
1616 template <typename OpTy>
1617 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1618  const CollapsingInfo &collapsingInfo) {
1619  return nullptr;
1620 }
1621 
1622 /// Collapse any `LinalgOp` that does not require any specialization such as
1623 /// indexing_maps, iterator_types, etc.
1624 template <>
1625 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1626  const CollapsingInfo &collapsingInfo) {
1627  SmallVector<Value> inputOperands, outputOperands;
1628  SmallVector<Type> resultTypes;
1629  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1630  outputOperands, resultTypes);
1631 
1632  return clone(
1633  rewriter, origOp, resultTypes,
1634  llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1635 }
1636 
1637 /// Collapse a `GenericOp`
1638 template <>
 1639 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
 1640  GenericOp origOp,
1641  const CollapsingInfo &collapsingInfo) {
1642  SmallVector<Value> inputOperands, outputOperands;
1643  SmallVector<Type> resultTypes;
1644  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1645  outputOperands, resultTypes);
1646  SmallVector<AffineMap> indexingMaps(
1647  llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1648  return getCollapsedOpIndexingMap(map, collapsingInfo);
1649  }));
1650 
 1651  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
 1652  origOp.getIteratorTypesArray(), collapsingInfo));
1653 
1654  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1655  origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1656  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1657  Block *origOpBlock = &origOp->getRegion(0).front();
1658  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1659  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1660  collapsedOpBlock->getArguments());
1661  return collapsedOp;
1662 }
1663 
1664 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1665  RewriterBase &rewriter) {
1666  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1667  return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1668  } else {
1669  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1670  }
1671 }
1672 
1673 /// Implementation of fusion with reshape operation by collapsing dimensions.
1674 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1675  LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1676  RewriterBase &rewriter) {
1677  // Bail on trivial no-op cases.
1678  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1679  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1680  return foldedDims.size() <= 1;
1681  }))
1682  return failure();
1683 
1684  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1685  if (hasPureBufferSemantics &&
1686  !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1687  MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1688  if (!memRefToCollapse)
1689  return true;
1690 
1691  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1692  memRefToCollapse, foldedIterationDims);
1693  }))
1694  return rewriter.notifyMatchFailure(op,
1695  "memref is not guaranteed collapsible");
1696 
1697  CollapsingInfo collapsingInfo;
1698  if (failed(
1699  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1700  return rewriter.notifyMatchFailure(
1701  op, "illegal to collapse specified dimensions");
1702  }
1703 
1704  // Bail on non-canonical ranges.
1705  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1706  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1707  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1708  return cast<IntegerAttr>(attr).getInt() == value;
1709  llvm::APInt actual;
1710  return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1711  actual.getSExtValue() == value;
1712  };
1713  if (!llvm::all_of(loopRanges, [&](Range range) {
1714  return opFoldIsConstantValue(range.offset, 0) &&
1715  opFoldIsConstantValue(range.stride, 1);
1716  })) {
1717  return rewriter.notifyMatchFailure(
1718  op, "expected all loop ranges to have zero start and unit stride");
1719  }
1720 
1721  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1722 
1723  Location loc = op->getLoc();
1724  if (collapsedOp.hasIndexSemantics()) {
1725  // Collect the loop range of the generic op.
1726  OpBuilder::InsertionGuard g(rewriter);
1727  rewriter.setInsertionPoint(collapsedOp);
1728  SmallVector<Value> loopBound =
1729  llvm::map_to_vector(loopRanges, [&](Range range) {
1730  return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1731  });
1732  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1733  collapsingInfo, loopBound, rewriter);
1734  }
1735 
 1736  // Insert an expanding reshape for the result to get back the original
 1737  // result type.
1738  SmallVector<Value> results;
1739  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1740  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1741  auto originalResultType =
1742  cast<ShapedType>(originalResult.value().getType());
1743  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1744  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1745  AffineMap indexingMap =
1746  op.getIndexingMapMatchingResult(originalResult.value());
1747  SmallVector<ReassociationIndices> reassociation =
1748  getOperandReassociation(indexingMap, collapsingInfo);
1749  Value result;
1750  if (isa<MemRefType>(collapsedOpResult.getType())) {
1751  MemRefType expandShapeResultType = MemRefType::get(
1752  originalResultType.getShape(), originalResultType.getElementType());
1753  result = rewriter.create<memref::ExpandShapeOp>(
1754  loc, expandShapeResultType, collapsedOpResult, reassociation);
1755  } else {
1756  result = rewriter.create<tensor::ExpandShapeOp>(
1757  loc, originalResultType, collapsedOpResult, reassociation);
1758  }
1759  results.push_back(result);
1760  } else {
1761  results.push_back(collapsedOpResult);
1762  }
1763  }
1764  return CollapseResult{results, collapsedOp};
1765 }
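// Hedged caller-side sketch (not from the original source): a transform would
// typically drive this entry point as
//   FailureOr<CollapseResult> res = collapseOpIterationDims(
//       linalgOp, /*foldedIterationDims=*/{{0, 1}}, rewriter);
//   if (succeeded(res))
//     rewriter.replaceOp(linalgOp, res->results);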
1766 
1767 namespace {
1768 
1769 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1770 /// contracting dimensions of the loop.
1771 class FoldWithProducerReshapeOpByCollapsing
1772  : public OpRewritePattern<GenericOp> {
1773 public:
1774  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1775  ControlFusionFn foldReshapes,
1776  PatternBenefit benefit = 1)
1777  : OpRewritePattern<GenericOp>(context, benefit),
1778  controlFoldingReshapes(std::move(foldReshapes)) {}
1779 
1780  LogicalResult matchAndRewrite(GenericOp genericOp,
1781  PatternRewriter &rewriter) const override {
1782  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1783  tensor::ExpandShapeOp reshapeOp =
1784  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1785  if (!reshapeOp)
1786  continue;
1787 
1788  SmallVector<ReassociationIndices> collapsableIterationDims =
1789  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1790  reshapeOp.getReassociationIndices());
1791  if (collapsableIterationDims.empty() ||
1792  !controlFoldingReshapes(&opOperand)) {
1793  continue;
1794  }
1795 
1796  std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1797  genericOp, collapsableIterationDims, rewriter);
1798  if (!collapseResult) {
1799  return rewriter.notifyMatchFailure(
1800  genericOp, "failed to do the fusion by collapsing transformation");
1801  }
1802 
1803  rewriter.replaceOp(genericOp, collapseResult->results);
1804  return success();
1805  }
1806  return failure();
1807  }
1808 
1809 private:
1810  ControlFusionFn controlFoldingReshapes;
1811 };
1812 
1813 class FoldPadWithProducerReshapeOpByCollapsing
1814  : public OpRewritePattern<tensor::PadOp> {
1815 public:
1816  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1817  ControlFusionFn foldReshapes,
1818  PatternBenefit benefit = 1)
1819  : OpRewritePattern<tensor::PadOp>(context, benefit),
1820  controlFoldingReshapes(std::move(foldReshapes)) {}
1821 
1822  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1823  PatternRewriter &rewriter) const override {
1824  tensor::ExpandShapeOp reshapeOp =
1825  padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1826  if (!reshapeOp)
1827  return failure();
1828  if (!reshapeOp->hasOneUse())
1829  return failure();
1830 
1831  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1832  return rewriter.notifyMatchFailure(padOp,
1833  "fusion blocked by control function");
1834  }
1835 
1836  ArrayRef<int64_t> low = padOp.getStaticLow();
1837  ArrayRef<int64_t> high = padOp.getStaticHigh();
1838  SmallVector<ReassociationIndices> reassociations =
1839  reshapeOp.getReassociationIndices();
1840 
1841  for (auto reInd : reassociations) {
1842  if (reInd.size() == 1)
1843  continue;
1844  if (llvm::any_of(reInd, [&](int64_t ind) {
1845  return low[ind] != 0 || high[ind] != 0;
1846  })) {
1847  return failure();
1848  }
1849  }
1850 
1851  SmallVector<OpFoldResult> newLow, newHigh;
1852  RankedTensorType collapsedType = reshapeOp.getSrcType();
1853  RankedTensorType paddedType = padOp.getResultType();
1854  SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1855  SmallVector<OpFoldResult> expandedPaddedSizes(
1856  getMixedValues(reshapeOp.getStaticOutputShape(),
1857  reshapeOp.getOutputShape(), rewriter));
1858  AffineExpr d0, d1, d2;
1859  bindDims(rewriter.getContext(), d0, d1, d2);
1860  auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1861  Location loc = reshapeOp->getLoc();
1862  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1863  OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1864  OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1865  if (reInd.size() == 1) {
1866  collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
 1867  OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
 1868  rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1869  expandedPaddedSizes[reInd[0]] = paddedSize;
1870  }
1871  newLow.push_back(l);
1872  newHigh.push_back(h);
1873  }
1874 
1875  RankedTensorType collapsedPaddedType =
1876  paddedType.clone(collapsedPaddedShape);
1877  auto newPadOp = rewriter.create<tensor::PadOp>(
1878  loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1879  padOp.getConstantPaddingValue(), padOp.getNofold());
1880 
1881  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1882  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1883  expandedPaddedSizes);
1884 
1885  return success();
1886  }
1887 
1888 private:
1889  ControlFusionFn controlFoldingReshapes;
1890 };
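// Illustrative IR effect of the pattern above (shapes assumed, pad body and
// attributes elided): given
//   %e = tensor.expand_shape %s [[0], [1, 2]]
//       : tensor<10x12xf32> into tensor<10x3x4xf32>
//   %p = tensor.pad %e low[2, 0, 0] high[1, 0, 0] ...
// where only the singleton group [0] is padded, the pad is applied to the
// collapsed source and then re-expanded:
//   %np = tensor.pad %s low[2, 0] high[1, 0] ...
//       : tensor<10x12xf32> to tensor<13x12xf32>
//   %r = tensor.expand_shape %np [[0], [1, 2]]
//       : tensor<13x12xf32> into tensor<13x3x4xf32>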
1891 
1892 /// Pattern to collapse dimensions.
1893 template <typename LinalgType>
1894 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1895 public:
1896  CollapseLinalgDimensions(MLIRContext *context,
1897  GetCollapsableDimensionsFn collapseDimensions,
1898  PatternBenefit benefit = 1)
1899  : OpRewritePattern<LinalgType>(context, benefit),
1900  controlCollapseDimension(std::move(collapseDimensions)) {}
1901 
1902  LogicalResult matchAndRewrite(LinalgType op,
1903  PatternRewriter &rewriter) const override {
1904  SmallVector<ReassociationIndices> collapsableIterationDims =
1905  controlCollapseDimension(op);
1906  if (collapsableIterationDims.empty())
1907  return failure();
1908 
1909  // Check if the specified list of dimensions to collapse is a valid list.
1910  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
1911  collapsableIterationDims)) {
1912  return rewriter.notifyMatchFailure(
1913  op, "specified dimensions cannot be collapsed");
1914  }
1915 
1916  std::optional<CollapseResult> collapseResult =
1917  collapseOpIterationDims(op, collapsableIterationDims, rewriter);
1918  if (!collapseResult) {
1919  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
1920  }
1921  rewriter.replaceOp(op, collapseResult->results);
1922  return success();
1923  }
1924 
1925 private:
1926  GetCollapsableDimensionsFn controlCollapseDimension;
1927 };
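// Hedged usage sketch (names assumed): a control function for this pattern
// might request folding the two outermost loops of any op with at least three
// loops, e.g.
//   GetCollapsableDimensionsFn control = [](linalg::LinalgOp op) {
//     SmallVector<ReassociationIndices> dims;
//     if (op.getNumLoops() >= 3)
//       dims.push_back(ReassociationIndices{0, 1});
//     return dims;
//   };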
1928 
1929 } // namespace
1930 
1931 //===---------------------------------------------------------------------===//
1932 // Methods and patterns that fuse constants with linalg.generic operations.
1933 //===---------------------------------------------------------------------===//
1934 
1935 namespace {
1936 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1937 /// handle cases where the constant is not single-valued.
1938 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1939 public:
1940  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1941  : OpRewritePattern<GenericOp>(context, benefit) {}
1942 
1943  LogicalResult matchAndRewrite(GenericOp genericOp,
1944  PatternRewriter &rewriter) const override {
1945  if (!genericOp.hasPureTensorSemantics())
1946  return failure();
1947  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1948  Operation *def = opOperand->get().getDefiningOp();
1949  TypedAttr constantAttr;
1950  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1951  {
1952  DenseElementsAttr splatAttr;
1953  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1954  splatAttr.isSplat() &&
1955  splatAttr.getType().getElementType().isIntOrFloat()) {
1956  constantAttr = splatAttr.getSplatValue<TypedAttr>();
1957  return true;
1958  }
1959  }
1960  {
1961  IntegerAttr intAttr;
1962  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1963  constantAttr = intAttr;
1964  return true;
1965  }
1966  }
1967  {
1968  FloatAttr floatAttr;
1969  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1970  constantAttr = floatAttr;
1971  return true;
1972  }
1973  }
1974  return false;
1975  };
1976 
1977  auto resultValue = dyn_cast<OpResult>(opOperand->get());
1978  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1979  continue;
1980 
 1981  // The operands and the indexing_maps of the fused operation are the same
 1982  // as the operands and indexing_maps of the generic operation, with the
 1983  // values at the constant index dropped.
1984  SmallVector<AffineMap> fusedIndexMaps;
1985  SmallVector<Value> fusedOperands;
1986  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1987  fusedIndexMaps.reserve(genericOp->getNumOperands());
1988  fusedOperands.reserve(genericOp.getNumDpsInputs());
1989  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1990  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1991  if (inputOperand == opOperand)
1992  continue;
1993  Value inputValue = inputOperand->get();
1994  fusedIndexMaps.push_back(
1995  genericOp.getMatchingIndexingMap(inputOperand));
1996  fusedOperands.push_back(inputValue);
1997  fusedLocs.push_back(inputValue.getLoc());
1998  }
1999  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2000  fusedIndexMaps.push_back(
2001  genericOp.getMatchingIndexingMap(&outputOperand));
2002 
2003  // Check if the operation shapes to loops map is computable.
2004  if (!inversePermutation(
2005  concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2006  return rewriter.notifyMatchFailure(
2007  genericOp, "fused op loop bound computation failed");
2008  }
2009 
2010  // Create a constant scalar value from the splat constant.
2011  Value scalarConstant =
2012  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
2013 
2014  SmallVector<Value> outputOperands = genericOp.getOutputs();
2015  auto fusedOp = rewriter.create<GenericOp>(
2016  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2017  /*inputs=*/fusedOperands,
2018  /*outputs=*/outputOperands,
2019  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2020  genericOp.getIteratorTypes(),
2021  /*doc=*/nullptr,
2022  /*library_call=*/nullptr);
2023 
2024  // Map the block argument corresponding to the replaced argument with the
2025  // scalar constant.
2026  Region &region = genericOp->getRegion(0);
2027  Block &entryBlock = *region.begin();
2028  IRMapping mapping;
2029  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2030  scalarConstant);
2031  Region &fusedRegion = fusedOp->getRegion(0);
2032  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2033  mapping);
2034  rewriter.replaceOp(genericOp, fusedOp->getResults());
2035  return success();
2036  }
2037  return failure();
2038  }
2039 };
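// Illustrative IR effect of the pattern above (sketch): an input fed by
//   %cst = arith.constant dense<4.000000e+00> : tensor<8xf32>
// is dropped from the fused op's operand list, and the payload use of the
// corresponding block argument is replaced by the scalar
//   %c = arith.constant 4.000000e+00 : f32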
2040 
2041 } // namespace
2042 
2043 //===---------------------------------------------------------------------===//
2044 // Miscellaneous patterns that help fusion.
2045 //===---------------------------------------------------------------------===//
2046 
2047 namespace {
2048 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2049 /// value of the `outs` operand is not used within the op. This is only
2050 /// implemented for `linalg.generic` operations for now, but should hold for all
2051 /// linalg structured ops.
2052 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
 2053  using OpRewritePattern<GenericOp>::OpRewritePattern;
 2054 
2055  LogicalResult matchAndRewrite(GenericOp op,
2056  PatternRewriter &rewriter) const override {
2057  rewriter.startOpModification(op);
2058  bool modifiedOutput = false;
2059  Location loc = op.getLoc();
2060  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2061  if (!op.payloadUsesValueFromOperand(&opOperand)) {
2062  Value operandVal = opOperand.get();
2063  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2064  if (!operandType)
2065  continue;
2066 
2067  // If outs is sparse, leave it to the sparsifier.
 2068  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
 2069  continue;
2070 
2071  // If outs is already an `empty` operation, nothing to do.
2072  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2073  if (definingOp)
2074  continue;
2075  modifiedOutput = true;
2076  SmallVector<OpFoldResult> mixedSizes =
2077  tensor::getMixedSizes(rewriter, loc, operandVal);
2078  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
2079  loc, mixedSizes, operandType.getElementType());
2080  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2081  }
2082  }
2083  if (!modifiedOutput) {
2084  rewriter.cancelOpModification(op);
2085  return failure();
2086  }
2087  rewriter.finalizeOpModification(op);
2088  return success();
2089  }
2090 };
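// Illustrative IR effect of the pattern above (sketch): if the payload never
// reads the outs value, an init such as
//   %1 = linalg.generic ... outs(%filled : tensor<8xf32>) ...
// is rewritten to use a fresh, unconstrained init:
//   %e = tensor.empty() : tensor<8xf32>
//   %1 = linalg.generic ... outs(%e : tensor<8xf32>) ...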
2091 
2092 /// Fold linalg.fill into linalg.generic
2093 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
 2094  using OpRewritePattern<GenericOp>::OpRewritePattern;
 2095 
2096  LogicalResult matchAndRewrite(GenericOp genericOp,
2097  PatternRewriter &rewriter) const override {
2098  if (!genericOp.hasPureTensorSemantics())
2099  return failure();
2100  bool fillFound = false;
2101  Block &payload = genericOp.getRegion().front();
2102  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2103  if (!genericOp.payloadUsesValueFromOperand(opOperand))
2104  continue;
2105  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2106  if (!fillOp)
2107  continue;
2108  fillFound = true;
2109  Value fillVal = fillOp.value();
2110  auto resultType =
2111  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2112  Value convertedVal =
2113  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2114  /*isUnsignedCast =*/false);
2115  rewriter.replaceAllUsesWith(
2116  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2117  }
2118  return success(fillFound);
2119  }
2120 };
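// Illustrative IR effect of the pattern above (sketch): for
//   %f = linalg.fill ins(%v : f32) outs(%init : tensor<8xf32>) -> tensor<8xf32>
// feeding an input of the generic op, the payload block argument of that
// input is replaced directly by %v (cast to the input element type if
// needed), leaving the fill result unused in the payload.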
2121 } // namespace
2122 
 2122 
 2123 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
 2124  RewritePatternSet &patterns,
 2125  const ControlFusionFn &controlFoldingReshapes) {
2126  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2127  controlFoldingReshapes);
2128  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2129  controlFoldingReshapes);
2130  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2131  controlFoldingReshapes);
2132 }
2133 
 2134 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
 2135  RewritePatternSet &patterns,
 2136  const ControlFusionFn &controlFoldingReshapes) {
2137  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2138  controlFoldingReshapes);
2139  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2140  patterns.getContext(), controlFoldingReshapes);
2141 }
2142 
 2143 void mlir::linalg::populateElementwiseOpsFusionPatterns(
 2144  RewritePatternSet &patterns,
 2145  const ControlFusionFn &controlElementwiseOpsFusion) {
2146  auto *context = patterns.getContext();
2147  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2148  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2149  RemoveOutsDependency>(context);
2150  // Add the patterns that clean up dead operands and results.
 2151  populateEraseUnusedOperandsAndResultsPatterns(patterns);
 2152 }
2153 
 2154 void mlir::linalg::populateCollapseDimensions(
 2155  RewritePatternSet &patterns,
 2156  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2157  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2158  CollapseLinalgDimensions<linalg::CopyOp>>(
2159  patterns.getContext(), controlCollapseDimensions);
2160 }
2161 
2162 //===---------------------------------------------------------------------===//
2163 // Passes
2164 //===---------------------------------------------------------------------===//
2165 
2166 namespace {
2167 
2168 /// Pass that fuses generic ops on tensors. Used only for testing.
2169 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2170 // patterns added here heavily depends on the cost function used. Having an
2171 // opinionated pass of this form is not recommended. Deprecate this pass in
2172 // favor of test passes that check the functionality of each of the patterns
2173 // added here individually.
2174 struct LinalgElementwiseOpFusionPass
2175  : public impl::LinalgElementwiseOpFusionPassBase<
2176  LinalgElementwiseOpFusionPass> {
2177  using impl::LinalgElementwiseOpFusionPassBase<
2178  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2179  void runOnOperation() override {
2180  Operation *op = getOperation();
2181  MLIRContext *context = op->getContext();
2182  RewritePatternSet patterns(context);
2183 
2184  // Add folding with reshape by expansion patterns.
2185  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2186  Operation *producer = fusedOperand->get().getDefiningOp();
2187  return producer && producer->hasOneUse();
2188  };
2189 
2190  // Add elementwise op fusion patterns.
 2191  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
 2192  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
 2193  tensor::populateBubbleUpExpandShapePatterns(patterns);
 2194 
2195  // General canonicalization patterns.
2196  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2197  GenericOp::getCanonicalizationPatterns(patterns, context);
2198  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2199  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2200  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2201  patterns);
2202 
2203  // Add constant folding patterns.
 2204  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
 2205 
 2206  // Use TopDownTraversal for compile-time reasons.
2207  GreedyRewriteConfig grc;
2208  grc.useTopDownTraversal = true;
2209  (void)applyPatternsGreedily(op, std::move(patterns), grc);
2210  }
2211 };
2212 
2213 } // namespace