//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
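
// As a worked illustration (not in the original source): suppose the producer
// writes its result transposed, i.e. producerResultIndexMap =
// (d0, d1) -> (d1, d0), reads its argument with argMap =
// (d0, d1) -> (d0, d1), and the consumer reads the producer's result with
// fusedConsumerArgIndexMap = (d0, d1) -> (d0, d1). Then
//
//   invProducerResultIndexMap       = (d0, d1) -> (d1, d0)
//   t1 = argMap.compose(inv)        = (d0, d1) -> (d1, d0)
//   t1.compose(fusedConsumerArgMap) = (d0, d1) -> (d1, d0)
//
// i.e. at iteration (i, j) of the fused loop nest, the producer's argument is
// read at (j, i), which is exactly where the producer read it when computing
// result[i, j].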

// Checks if the given operand can be dropped, and the remaining operands
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
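
// For instance (illustrative): if the remaining indexing maps are
// (d0, d1) -> (d0) and (d0, d1) -> (d1), their concatenation
// (d0, d1) -> (d0, d1) is invertible and the candidate operand can be
// dropped. If instead both remaining maps were (d0, d1) -> (d0), loop d1
// would no longer be bounded by any operand and the drop is rejected.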

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}
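
// For instance (illustrative): for a producer with two results where the
// first is used only as the fused operand of the consumer and the second is
// also consumed by an unrelated op, only index 1 ends up in the returned set;
// result 0 can be dropped from the fused op.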

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
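
// As an illustration (not from the original source), this pair satisfies the
// conditions above: the producer is all-parallel, its result indexing map is
// an (identity) permutation, and the fused operand is an input of the
// consumer:
//
// ```mlir
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%a : tensor<?x?xf32>) outs(%init0 : tensor<?x?xf32>) {
//     ^bb0(%in: f32, %out: f32):
//       %e = math.exp %in : f32
//       linalg.yield %e : f32
//   } -> tensor<?x?xf32>
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//       outs(%init1 : tensor<?x?xf32>) {
//     ^bb0(%in0: f32, %in1: f32, %out: f32):
//       %s = arith.addf %in0, %in1 : f32
//       linalg.yield %s : f32
//   } -> tensor<?x?xf32>
// ```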

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
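
// To make the block-argument order concrete (an illustration, not from the
// original source): if the consumer has inputs (c0, fused, c1) and one init
// c_out, and the producer has inputs (p0, p1) and one preserved init p_out,
// the fused block arguments are, in order:
//
//   c0, p0, p1, c1, p_out, c_out
//
// with the consumer's `fused` argument replaced by the producer's remapped
// yield value (step 9).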

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      linalg::getPreservedProducerResults(producer, consumer, fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
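
// Continuing the illustrative exp/add pair shown after
// `areElementwiseOpsFusable`, the fused op produced here is (a sketch):
//
// ```mlir
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   %fused = linalg.generic {indexing_maps = [#id, #id, #id],
//                            iterator_types = ["parallel", "parallel"]}
//       ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//       outs(%init1 : tensor<?x?xf32>) {
//     ^bb0(%arg_a: f32, %arg_b: f32, %out: f32):
//       %e = math.exp %arg_a : f32
//       %s = arith.addf %e, %arg_b : f32
//       linalg.yield %s : f32
//   } -> tensor<?x?xf32>
// ```
//
// The producer's result is not preserved because its only use was the fused
// operand of the consumer.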

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Find the producer of the operand.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the fusion.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace
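
// A minimal usage sketch (not part of this file; assumes the
// `populateElementwiseOpsFusionPatterns` helper declared in Linalg's
// Transforms.h and the greedy driver entry point `applyPatternsGreedily`):
//
//   RewritePatternSet patterns(ctx);
//   // Fuse wherever the preconditions hold.
//   linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();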

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the linalg generic now has six dimensions,
///  reshapes can be introduced on the operands to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of loops in the original operation to the extent
  /// of loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
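
// A small worked example (illustrative, not from the original source): for an
// op with loops (d0, d1), a fused operand whose indexing map is
// (d0, d1) -> (d0, d1), and a reshape with reassociation [[0, 1], [2]]
// expanding the operand's shape to (?, 4, ?):
//
//   numExpandedDims   = [2, 1]
//   expandedShapeMap  = [[?, 4], [?]]
//   reassociation     = [[0, 1], [2]]   (d0 -> (e0, e1), d1 -> (e2))
//   expandedOpNumDims = 3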

/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
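
// For example (illustrative): with the expansion d0 -> (e0, e1), d1 -> (e2)
// from above, the original map (d0, d1) -> (d1, d0) becomes
// (e0, e1, e2) -> (e2, e0, e1) in the expanded op.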

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    OpFoldResult newIndex =
        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}
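
// Concretely (an illustration, not from the original source): if original
// loop d0 is expanded into (e0, e1) with inner extent s1, the original
// `linalg.index 0` is recovered as
//
//   %i0 = linalg.index 0 : index   // e0
//   %i1 = linalg.index 1 : index   // e1
//   // idx + acc * shape, with idx = %i1, acc = %i0, shape = s1:
//   %orig = affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
//               (%i1, %i0)[%s1]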

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we inverse-permute it and
// then flatten it. Then we inverse-permute it again to get the final expanded
// transpose permutation. For example,
//
//   permutation = [2, 0, 1]
//   reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
//   inverse permutation = [1, 2, 0]
//   applied to reassociation_map and then flattened becomes
//   flattened permutation = [2, 3, 4, 5, 0, 1]
//   final permutation is the inverse of the flattened permutation.
//
// Becomes
//
//   permutation = [4, 5, 0, 1, 2, 3]
static TransposeOp createExpandedTransposeOp(PatternRewriter &rewriter,
                                             TransposeOp transposeOp,
                                             Value expandedInput, Value output,
                                             ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
                                      output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // The iterator types of the expanded op are all parallel.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = rewriter.create<GenericOp>(
      linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
      expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape, collapsedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in output shape before the `linalgOp`
    // to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
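
// For instance (illustrative): a pad of the non-collapsed dimension commutes
// with the collapse:
//
// ```mlir
//   %c = tensor.collapse_shape %s [[0, 1], [2]]
//       : tensor<2x3x10xf32> into tensor<6x10xf32>
//   %p = tensor.pad %c low[0, 2] high[0, 3] { ... }
//       : tensor<6x10xf32> to tensor<6x15xf32>
// ```
//
// becomes a pad of the expanded source followed by the same collapse:
//
// ```mlir
//   %p = tensor.pad %s low[0, 0, 2] high[0, 0, 3] { ... }
//       : tensor<2x3x10xf32> to tensor<2x3x15xf32>
//   %c = tensor.collapse_shape %p [[0, 1], [2]]
//       : tensor<2x3x15xf32> into tensor<6x15xf32>
// ```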

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the results of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Return
/// `std::nullopt` on failure. Ensures that all the elements of the returned
/// reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
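
// For example (illustrative): with indexingMap = (d0, d1, d2) -> (d2, d0, d1)
// and rangeReassociation = [0, 1] (the first two results), the returned
// domain reassociation is [2, 0].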

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}
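
// For example (illustrative): the sequence [1, 2] is preserved by
// (d0, d1, d2) -> (d0, d1, d2) and by (d0, d1, d2) -> (d0) (the sequence does
// not appear at all), but not by (d0, d1, d2) -> (d2, d1), where d1 and d2
// appear out of order.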

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %2 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are all parallel or all reduction.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert_range(foldedIterationSpaceDims);
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}
1391 
1392 /// Helper class to carry state while collapsing the `linalg.generic` op.
1393 namespace {
1394 class CollapsingInfo {
1395 public:
1396  LogicalResult initialize(unsigned origNumLoops,
1397  ArrayRef<ReassociationIndices> foldedIterationDims) {
1398  llvm::SmallDenseSet<int64_t, 4> processedDims;
1399  // Find all the dims that are folded.
1400  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1401  if (foldedIterationDim.empty())
1402  continue;
1403  // If the folded dims contain dims already folded, that is an illegal
1404  // specification. Repetition within a list is also illegal.
1405  for (auto dim : foldedIterationDim) {
1406  if (dim >= origNumLoops)
1407  return failure();
1408  if (processedDims.count(dim))
1409  return failure();
1410  processedDims.insert(dim);
1411  }
1412  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1413  foldedIterationDim.end());
1414  }
1415  if (processedDims.size() > origNumLoops)
1416  return failure();
1417 
1418  // Add all the preserved dims of the original op as single
1419  // elements to `collapsedOpToOrigOpIterationDim`.
1420  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1421  if (processedDims.count(dim))
1422  continue;
1423  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1424  }
1425 
1426  llvm::sort(collapsedOpToOrigOpIterationDim,
1427  [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1428  return lhs[0] < rhs[0];
1429  });
1430  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1431  for (const auto &foldedDims :
1432  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1433  for (const auto &dim : enumerate(foldedDims.value()))
1434  origOpToCollapsedOpIterationDim[dim.value()] =
1435  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1436  }
1437  return success();
1438  }
1439 
1440  /// Return mapping from collapsed loop domain to original loop domain.
1441  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1442  return collapsedOpToOrigOpIterationDim;
1443  }
1444 
1445  /// Return mapping from original loop domain to collapsed loop domain. The
1446  /// mapping is a pair. First value is the dimension in the collapsed loop that
1447  /// the original loop is mapped to. Second is the relative position in folded
1448  /// list of this domain. For example, if the original loop domain is 5D and
1449  /// is folded into two dimensions as
1450  ///
1451  /// ```
1452  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1453  /// ```
1454  ///
1455  /// then
1456  ///
1457  /// ```
1458  /// origOpToCollapsedOpMapping[0] = {0, 0};
1459  /// origOpToCollapsedOpMapping[1] = {0, 1};
1460  /// origOpToCollapsedOpMapping[2] = {0, 2};
1461  /// origOpToCollapsedOpMapping[3] = {1, 0};
1462  /// origOpToCollapsedOpMapping[4] = {1, 1};
1463  /// ```
1464  ///
1465  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1466  return origOpToCollapsedOpIterationDim;
1467  }
1468 
1469  /// Return the collapsed op iteration domain rank.
1470  unsigned getCollapsedOpIterationRank() const {
1471  return collapsedOpToOrigOpIterationDim.size();
1472  }
1473 
1474 private:
1475  /// Map from the iteration domain index in collapsed op to the iteration
1476  /// domain indices in the original op.
1477  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1478 
1479  /// Map from iteration domain index in the original op to the iteration domain
1480  /// index in the collapsed op.
1481  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1482 };
1483 } // namespace
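// Illustrative use of this helper (a minimal sketch; the values are
// hypothetical): for an op with 5 loops and
// foldedIterationDims = {{0, 1, 2}, {3, 4}},
//
//   CollapsingInfo info;
//   if (failed(info.initialize(/*origNumLoops=*/5, foldedIterationDims)))
//     return failure(); // rejected: out-of-range or repeated dimensions.
//   unsigned rank = info.getCollapsedOpIterationRank(); // rank == 2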
1484 
1485 /// Get the iterator types for the collapsed operation given the original
1486 /// iterator types and collapsed dimensions.
1487 static SmallVector<utils::IteratorType>
1488 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1489  const CollapsingInfo &collapsingInfo) {
1490  SmallVector<utils::IteratorType> collapsedIteratorTypes;
1491  for (ReassociationIndicesRef foldedIterDims :
1492  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1493  assert(!foldedIterDims.empty() &&
1494  "reassociation indices expected to have non-empty sets");
1495  // Just pick the iterator type of the first folded dim. Precondition checks
1496  // are expected to have verified that the iterator types of all folded
1497  // dimensions are the same.
1498  collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1499  }
1500  return collapsedIteratorTypes;
1501 }
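// For example (illustrative): with original iterator types
// ["parallel", "parallel", "reduction"] and a collapsed-to-original mapping
// of [[0, 1], [2]], the collapsed op gets ["parallel", "reduction"].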
1502 
1503 /// Compute the indexing map in the collapsed op that corresponds to the given
1504 /// `indexingMap` of the original operation.
1505 static AffineMap
1506 getCollapsedOpIndexingMap(AffineMap indexingMap,
1507  const CollapsingInfo &collapsingInfo) {
1508  MLIRContext *context = indexingMap.getContext();
1509  assert(indexingMap.isProjectedPermutation() &&
1510  "expected indexing map to be projected permutation");
1511  SmallVector<AffineExpr> resultExprs;
1512  auto origOpToCollapsedOpMapping =
1513  collapsingInfo.getOrigOpToCollapsedOpMapping();
1514  for (auto expr : indexingMap.getResults()) {
1515  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1516  // If the dim is not the first dim of its collapsed group, do nothing.
1517  if (origOpToCollapsedOpMapping[dim].second != 0)
1518  continue;
1519  // The next n-dims are guaranteed to be collapsed. So just use the
1520  // iteration dimension of the collapsed op.
1521  resultExprs.push_back(
1522  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1523  }
1524  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1525  resultExprs, context);
1526 }
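// For example (illustrative): folding the iteration space as [[0, 1], [2]]
// rewrites the original indexing map
//
//   affine_map<(d0, d1, d2) -> (d2, d0, d1)>
//
// into the collapsed map
//
//   affine_map<(d0, d1) -> (d1, d0)>
//
// since original dims 0 and 1 fold into collapsed dim 0, and original dim 2
// becomes collapsed dim 1.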
1527 
1528 /// Return the `reassociation` indices to use to collapse the operand when the
1529 /// iteration space of a generic op is collapsed.
1530 static SmallVector<ReassociationIndices>
1531 getOperandReassociation(AffineMap indexingMap,
1532  const CollapsingInfo &collapsingInfo) {
1533  unsigned counter = 0;
1534  SmallVector<ReassociationIndices> operandReassociation;
1535  auto origOpToCollapsedOpMapping =
1536  collapsingInfo.getOrigOpToCollapsedOpMapping();
1537  auto collapsedOpToOrigOpMapping =
1538  collapsingInfo.getCollapsedOpToOrigOpMapping();
1539  while (counter < indexingMap.getNumResults()) {
1540  unsigned dim =
1541  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1542  // This is the start of a collapsed group of iteration dimensions that
1543  // is guaranteed to be preserved in the indexing map. The number of folded
1544  // dims is obtained from the collapsed op to original op mapping.
1545  unsigned numFoldedDims =
1546  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1547  .size();
1548  if (origOpToCollapsedOpMapping[dim].second == 0) {
1549  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1550  operandReassociation.emplace_back(range.begin(), range.end());
1551  }
1552  counter += numFoldedDims;
1553  }
1554  return operandReassociation;
1555 }
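// Continuing the example above (illustrative): for the original indexing map
// affine_map<(d0, d1, d2) -> (d2, d0, d1)> and iteration-space folding
// [[0, 1], [2]], the operand reassociation is [[0], [1, 2]]: result 0 (d2)
// stays a singleton group, while results 1 and 2 (d0, d1) collapse together.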
1556 
1557 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1558 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1559  OpOperand *opOperand,
1560  const CollapsingInfo &collapsingInfo,
1561  OpBuilder &builder) {
1562  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1563  SmallVector<ReassociationIndices> operandReassociation =
1564  getOperandReassociation(indexingMap, collapsingInfo);
1565 
1566  // If the number of entries in the reassociation for the operand is the
1567  // same as the number of results of the indexing map, then there is
1568  // nothing to do for this operand.
1569  Value operand = opOperand->get();
1570  if (operandReassociation.size() == indexingMap.getNumResults())
1571  return operand;
1572 
1573  // Insert a reshape to collapse the dimensions.
1574  if (isa<MemRefType>(operand.getType())) {
1575  return builder
1576  .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1577  .getResult();
1578  }
1579  return builder
1580  .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1581  .getResult();
1582 }
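// For example (illustrative): an operand of type tensor<4x8x2xf32> with
// operand reassociation [[0], [1, 2]] is collapsed by inserting
//
// ```mlir
// %collapsed = tensor.collapse_shape %operand [[0], [1, 2]]
//     : tensor<4x8x2xf32> into tensor<4x16xf32>
// ```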
1583 
1584 /// Modify the `linalg.index` operations in the original generic op, to its
1585 /// value in the collapsed operation.
1586 static void generateCollapsedIndexingRegion(
1587  Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1588  ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1589  OpBuilder::InsertionGuard g(rewriter);
1590  rewriter.setInsertionPointToStart(block);
1591 
1592  // Collect all the original index ops.
1593  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1594 
1595  // For each folded dimension list resolve the original induction variable
1596  // values in terms of the folded dimension induction variable.
1597  // i_{folded} = (i_0 * d1 + i_1) * d2 + i_2
1598  // can be inverted to
1599  // i_2 = i_{folded} % d2
1600  // i_1 = (i_{folded} / d2) % d1
1601  // i_0 = i_{folded} / (d1 * d2)
1602  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1603  for (auto foldedDims :
1604  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1605  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1606  Value newIndexVal =
1607  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1608  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1609  Value loopDim =
1610  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1611  indexReplacementVals[dim] =
1612  rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1613  newIndexVal =
1614  rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1615  }
1616  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1617  }
1618 
1619  for (auto indexOp : indexOps) {
1620  auto dim = indexOp.getDim();
1621  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1622  }
1623 }
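// For example (illustrative): if original dims 0 and 1 (with loop sizes %d0
// and %d1) are folded into collapsed dim 0, the replacement values generated
// here are
//
// ```mlir
// %folded = linalg.index 0 : index
// %i1 = arith.remsi %folded, %d1 : index
// %i0 = arith.divsi %folded, %d1 : index
// ```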
1624 
1625 void collapseOperandsAndResults(LinalgOp op,
1626  const CollapsingInfo &collapsingInfo,
1627  RewriterBase &rewriter,
1628  SmallVectorImpl<Value> &inputOperands,
1629  SmallVectorImpl<Value> &outputOperands,
1630  SmallVectorImpl<Type> &resultTypes) {
1631  Location loc = op->getLoc();
1632  inputOperands =
1633  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1634  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1635  rewriter);
1636  });
1637 
1638  // Get the output operands and result types.
1639  resultTypes.reserve(op.getNumDpsInits());
1640  outputOperands.reserve(op.getNumDpsInits());
1641  for (OpOperand &output : op.getDpsInitsMutable()) {
1642  Value newOutput =
1643  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1644  outputOperands.push_back(newOutput);
1645  // If the op has "buffer semantics", then the init operands are ranked
1646  // memrefs and the op has no results.
1647  if (!op.hasPureBufferSemantics())
1648  resultTypes.push_back(newOutput.getType());
1649  }
1650 }
1651 
1652 /// Clone a `LinalgOp` to a collapsed version of the same name.
1653 template <typename OpTy>
1654 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1655  const CollapsingInfo &collapsingInfo) {
1656  return nullptr;
1657 }
1658 
1659 /// Collapse any `LinalgOp` that does not require any specialization such as
1660 /// indexing_maps, iterator_types, etc.
1661 template <>
1662 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1663  const CollapsingInfo &collapsingInfo) {
1664  SmallVector<Value> inputOperands, outputOperands;
1665  SmallVector<Type> resultTypes;
1666  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1667  outputOperands, resultTypes);
1668 
1669  return clone(
1670  rewriter, origOp, resultTypes,
1671  llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1672 }
1673 
1674 /// Collapse a `GenericOp`
1675 template <>
1676 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1677  GenericOp origOp,
1678  const CollapsingInfo &collapsingInfo) {
1679  SmallVector<Value> inputOperands, outputOperands;
1680  SmallVector<Type> resultTypes;
1681  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1682  outputOperands, resultTypes);
1683  SmallVector<AffineMap> indexingMaps(
1684  llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1685  return getCollapsedOpIndexingMap(map, collapsingInfo);
1686  }));
1687 
1688  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1689  origOp.getIteratorTypesArray(), collapsingInfo));
1690 
1691  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1692  origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1693  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1694  Block *origOpBlock = &origOp->getRegion(0).front();
1695  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1696  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1697  collapsedOpBlock->getArguments());
1698  return collapsedOp;
1699 }
1700 
1701 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1702  RewriterBase &rewriter) {
1703  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1704  return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1705  } else {
1706  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1707  }
1708 }
1709 
1710 /// Implementation of fusion with reshape operation by collapsing dimensions.
1711 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1712  LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1713  RewriterBase &rewriter) {
1714  // Bail on trivial no-op cases.
1715  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1716  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1717  return foldedDims.size() <= 1;
1718  }))
1719  return failure();
1720 
1721  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1722  if (hasPureBufferSemantics &&
1723  !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1724  MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1725  if (!memRefToCollapse)
1726  return true;
1727 
1728  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1729  memRefToCollapse, foldedIterationDims);
1730  }))
1731  return rewriter.notifyMatchFailure(op,
1732  "memref is not guaranteed collapsible");
1733 
1734  CollapsingInfo collapsingInfo;
1735  if (failed(
1736  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1737  return rewriter.notifyMatchFailure(
1738  op, "illegal to collapse specified dimensions");
1739  }
1740 
1741  // Bail on non-canonical ranges.
1742  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1743  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1744  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1745  return cast<IntegerAttr>(attr).getInt() == value;
1746  llvm::APInt actual;
1747  return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1748  actual.getSExtValue() == value;
1749  };
1750  if (!llvm::all_of(loopRanges, [&](Range range) {
1751  return opFoldIsConstantValue(range.offset, 0) &&
1752  opFoldIsConstantValue(range.stride, 1);
1753  })) {
1754  return rewriter.notifyMatchFailure(
1755  op, "expected all loop ranges to have zero start and unit stride");
1756  }
1757 
1758  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1759 
1760  Location loc = op->getLoc();
1761  SmallVector<OpFoldResult> loopBound =
1762  llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1763 
1764  if (collapsedOp.hasIndexSemantics()) {
1765  // Collect the loop range of the generic op.
1766  OpBuilder::InsertionGuard g(rewriter);
1767  rewriter.setInsertionPoint(collapsedOp);
1768  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1769  collapsingInfo, loopBound, rewriter);
1770  }
1771 
1772  // Insert expanding reshape for the result to get back the original result
1773  // type.
1774  SmallVector<Value> results;
1775  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1776  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1777  auto originalResultType =
1778  cast<ShapedType>(originalResult.value().getType());
1779  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1780  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1781  AffineMap indexingMap =
1782  op.getIndexingMapMatchingResult(originalResult.value());
1783  SmallVector<ReassociationIndices> reassociation =
1784  getOperandReassociation(indexingMap, collapsingInfo);
1785  assert(
1786  indexingMap.isProjectedPermutation() &&
1787  "Expected indexing map to be a projected permutation for collapsing");
1788  SmallVector<OpFoldResult> resultShape =
1789  applyPermutationMap(indexingMap, ArrayRef(loopBound));
1790  Value result;
1791  if (isa<MemRefType>(collapsedOpResult.getType())) {
1792  MemRefType expandShapeResultType = MemRefType::get(
1793  originalResultType.getShape(), originalResultType.getElementType());
1794  result = rewriter.create<memref::ExpandShapeOp>(
1795  loc, expandShapeResultType, collapsedOpResult, reassociation,
1796  resultShape);
1797  } else {
1798  result = rewriter.create<tensor::ExpandShapeOp>(
1799  loc, originalResultType, collapsedOpResult, reassociation,
1800  resultShape);
1801  }
1802  results.push_back(result);
1803  } else {
1804  results.push_back(collapsedOpResult);
1805  }
1806  }
1807  return CollapseResult{results, collapsedOp};
1808 }
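// Illustrative call site (a minimal sketch; `genericOp` and `rewriter` are
// assumed to be in scope):
//
//   SmallVector<ReassociationIndices> foldedDims = {{0, 1}};
//   FailureOr<CollapseResult> collapsed =
//       collapseOpIterationDims(genericOp, foldedDims, rewriter);
//   if (succeeded(collapsed))
//     rewriter.replaceOp(genericOp, collapsed->results);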
1809 
1810 namespace {
1811 
1812 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1813 /// contracting dimensions of the loop.
1814 class FoldWithProducerReshapeOpByCollapsing
1815  : public OpRewritePattern<GenericOp> {
1816 public:
1817  // TODO: support fusion with all linalg ops, not just generic.
1818  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1819  ControlFusionFn foldReshapes,
1820  PatternBenefit benefit = 1)
1821  : OpRewritePattern<GenericOp>(context, benefit),
1822  controlFoldingReshapes(std::move(foldReshapes)) {}
1823 
1824  LogicalResult matchAndRewrite(GenericOp genericOp,
1825  PatternRewriter &rewriter) const override {
1826  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1827  tensor::ExpandShapeOp reshapeOp =
1828  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1829  if (!reshapeOp)
1830  continue;
1831 
1832  SmallVector<ReassociationIndices> collapsableIterationDims =
1833  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1834  reshapeOp.getReassociationIndices());
1835  if (collapsableIterationDims.empty() ||
1836  !controlFoldingReshapes(&opOperand)) {
1837  continue;
1838  }
1839 
1840  std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1841  genericOp, collapsableIterationDims, rewriter);
1842  if (!collapseResult) {
1843  return rewriter.notifyMatchFailure(
1844  genericOp, "failed to do the fusion by collapsing transformation");
1845  }
1846 
1847  rewriter.replaceOp(genericOp, collapseResult->results);
1848  return success();
1849  }
1850  return failure();
1851  }
1852 
1853 private:
1854  ControlFusionFn controlFoldingReshapes;
1855 };
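// Illustrative effect of this pattern (sketch):
//
// ```mlir
// %e = tensor.expand_shape %a [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %r = linalg.generic ... ins(%e : tensor<?x4xf32>) ...
// ```
//
// becomes a linalg.generic operating on the collapsed tensor<?xf32> operand
// directly, with tensor.expand_shape ops reconstructing the original result
// types.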
1856 
1857 /// Pattern to fold a tensor.collapse_shape op with its producer generic op
1858 /// by expanding the dimensionality of the loop in the producer op.
1859 struct FoldReshapeWithGenericOpByCollapsing
1860  : public OpRewritePattern<tensor::CollapseShapeOp> {
1861 
1862  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1863  ControlFusionFn foldReshapes,
1864  PatternBenefit benefit = 1)
1865  : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1866  controlFoldingReshapes(std::move(foldReshapes)) {}
1867 
1868  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1869  PatternRewriter &rewriter) const override {
1870  // Fold only if all constraints of fusing with reshape by collapsing are
1871  // met.
1872  auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1873  if (!producerResult) {
1874  return rewriter.notifyMatchFailure(reshapeOp,
1875  "source not produced by an operation");
1876  }
1877 
1878  // TODO: support fusion with all linalg producers, not just generic.
1879  auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1880  if (!producer) {
1881  return rewriter.notifyMatchFailure(reshapeOp,
1882  "producer not a generic op");
1883  }
1884 
1885  SmallVector<ReassociationIndices> collapsableIterationDims =
1887  producer,
1888  producer.getDpsInitOperand(producerResult.getResultNumber()),
1889  reshapeOp.getReassociationIndices());
1890  if (collapsableIterationDims.empty()) {
1891  return rewriter.notifyMatchFailure(
1892  reshapeOp, "failed preconditions of fusion with producer generic op");
1893  }
1894 
1895  if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1896  return rewriter.notifyMatchFailure(reshapeOp,
1897  "fusion blocked by control function");
1898  }
1899 
1900  // Set the insertion point after `producer` because there could be uses
1901  // of `producer` between it and the `tensor.collapse_shape` op.
1902  rewriter.setInsertionPointAfter(producer);
1903  std::optional<CollapseResult> collapseResult =
1904  collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
1905  if (!collapseResult) {
1906  return rewriter.notifyMatchFailure(
1907  producer, "failed to do the fusion by collapsing transformation");
1908  }
1909 
1910  rewriter.replaceOp(producer, collapseResult->results);
1911  return success();
1912  }
1913 
1914 private:
1915  ControlFusionFn controlFoldingReshapes;
1916 };
1917 
1918 class FoldPadWithProducerReshapeOpByCollapsing
1919  : public OpRewritePattern<tensor::PadOp> {
1920 public:
1921  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1922  ControlFusionFn foldReshapes,
1923  PatternBenefit benefit = 1)
1924  : OpRewritePattern<tensor::PadOp>(context, benefit),
1925  controlFoldingReshapes(std::move(foldReshapes)) {}
1926 
1927  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1928  PatternRewriter &rewriter) const override {
1929  tensor::ExpandShapeOp reshapeOp =
1930  padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1931  if (!reshapeOp)
1932  return failure();
1933  if (!reshapeOp->hasOneUse())
1934  return failure();
1935 
1936  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1937  return rewriter.notifyMatchFailure(padOp,
1938  "fusion blocked by control function");
1939  }
1940 
1941  ArrayRef<int64_t> low = padOp.getStaticLow();
1942  ArrayRef<int64_t> high = padOp.getStaticHigh();
1943  SmallVector<ReassociationIndices> reassociations =
1944  reshapeOp.getReassociationIndices();
1945 
1946  for (auto reInd : reassociations) {
1947  if (reInd.size() == 1)
1948  continue;
1949  if (llvm::any_of(reInd, [&](int64_t ind) {
1950  return low[ind] != 0 || high[ind] != 0;
1951  })) {
1952  return failure();
1953  }
1954  }
1955 
1956  SmallVector<OpFoldResult> newLow, newHigh;
1957  RankedTensorType collapsedType = reshapeOp.getSrcType();
1958  RankedTensorType paddedType = padOp.getResultType();
1959  SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1960  SmallVector<OpFoldResult> expandedPaddedSizes(
1961  getMixedValues(reshapeOp.getStaticOutputShape(),
1962  reshapeOp.getOutputShape(), rewriter));
1963  AffineExpr d0, d1, d2;
1964  bindDims(rewriter.getContext(), d0, d1, d2);
1965  auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1966  Location loc = reshapeOp->getLoc();
1967  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1968  OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1969  OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1970  if (reInd.size() == 1) {
1971  collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1972  OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
1973  rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1974  expandedPaddedSizes[reInd[0]] = paddedSize;
1975  }
1976  newLow.push_back(l);
1977  newHigh.push_back(h);
1978  }
1979 
1980  RankedTensorType collapsedPaddedType =
1981  paddedType.clone(collapsedPaddedShape);
1982  auto newPadOp = rewriter.create<tensor::PadOp>(
1983  loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1984  padOp.getConstantPaddingValue(), padOp.getNofold());
1985 
1986  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1987  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1988  expandedPaddedSizes);
1989 
1990  return success();
1991  }
1992 
1993 private:
1994  ControlFusionFn controlFoldingReshapes;
1995 };
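// Illustrative effect of this pattern (sketch): with zero padding on the
// expanded group [0, 1],
//
// ```mlir
// %e = tensor.expand_shape %src [[0, 1], [2]]
//     : tensor<?x?xf32> into tensor<?x4x?xf32>
// %p = tensor.pad %e low[0, 0, 1] high[0, 0, 2] { ... }
// ```
//
// is rewritten to pad the collapsed source and expand afterwards:
//
// ```mlir
// %p2 = tensor.pad %src low[0, 1] high[0, 2] { ... }
// %r = tensor.expand_shape %p2 [[0, 1], [2]]
//     : tensor<?x?xf32> into tensor<?x4x?xf32>
// ```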
1996 
1997 /// Pattern to collapse dimensions.
1998 template <typename LinalgType>
1999 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2000 public:
2001  CollapseLinalgDimensions(MLIRContext *context,
2002  GetCollapsableDimensionsFn collapseDimensions,
2003  PatternBenefit benefit = 1)
2004  : OpRewritePattern<LinalgType>(context, benefit),
2005  controlCollapseDimension(std::move(collapseDimensions)) {}
2006 
2007  LogicalResult matchAndRewrite(LinalgType op,
2008  PatternRewriter &rewriter) const override {
2009  SmallVector<ReassociationIndices> collapsableIterationDims =
2010  controlCollapseDimension(op);
2011  if (collapsableIterationDims.empty())
2012  return failure();
2013 
2014  // Check if the specified list of dimensions to collapse is a valid list.
2015  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2016  collapsableIterationDims)) {
2017  return rewriter.notifyMatchFailure(
2018  op, "specified dimensions cannot be collapsed");
2019  }
2020 
2021  std::optional<CollapseResult> collapseResult =
2022  collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2023  if (!collapseResult) {
2024  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2025  }
2026  rewriter.replaceOp(op, collapseResult->results);
2027  return success();
2028  }
2029 
2030 private:
2031  GetCollapsableDimensionsFn controlCollapseDimension;
2032 };
2033 
2034 } // namespace
2035 
2036 //===---------------------------------------------------------------------===//
2037 // Methods and patterns that fuse constants with linalg.generic operations.
2038 //===---------------------------------------------------------------------===//
2039 
2040 namespace {
2041 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2042 /// handle cases where the constant is not single-valued.
2043 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2044 public:
2045  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2046  : OpRewritePattern<GenericOp>(context, benefit) {}
2047 
2048  LogicalResult matchAndRewrite(GenericOp genericOp,
2049  PatternRewriter &rewriter) const override {
2050  if (!genericOp.hasPureTensorSemantics())
2051  return failure();
2052  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2053  Operation *def = opOperand->get().getDefiningOp();
2054  TypedAttr constantAttr;
2055  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2056  {
2057  DenseElementsAttr splatAttr;
2058  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2059  splatAttr.isSplat() &&
2060  splatAttr.getType().getElementType().isIntOrFloat()) {
2061  constantAttr = splatAttr.getSplatValue<TypedAttr>();
2062  return true;
2063  }
2064  }
2065  {
2066  IntegerAttr intAttr;
2067  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2068  constantAttr = intAttr;
2069  return true;
2070  }
2071  }
2072  {
2073  FloatAttr floatAttr;
2074  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2075  constantAttr = floatAttr;
2076  return true;
2077  }
2078  }
2079  return false;
2080  };
2081 
2082  auto resultValue = dyn_cast<OpResult>(opOperand->get());
2083  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2084  continue;
2085 
2086  // The operands and the indexing_maps of the fused operation are the same
2087  // as the operands and indexing_maps of the generic operation, with the
2088  // values at the constant index dropped.
2089  SmallVector<AffineMap> fusedIndexMaps;
2090  SmallVector<Value> fusedOperands;
2091  SmallVector<Location> fusedLocs{genericOp.getLoc()};
2092  fusedIndexMaps.reserve(genericOp->getNumOperands());
2093  fusedOperands.reserve(genericOp.getNumDpsInputs());
2094  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2095  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2096  if (inputOperand == opOperand)
2097  continue;
2098  Value inputValue = inputOperand->get();
2099  fusedIndexMaps.push_back(
2100  genericOp.getMatchingIndexingMap(inputOperand));
2101  fusedOperands.push_back(inputValue);
2102  fusedLocs.push_back(inputValue.getLoc());
2103  }
2104  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2105  fusedIndexMaps.push_back(
2106  genericOp.getMatchingIndexingMap(&outputOperand));
2107 
2108  // Check if the operation shapes to loops map is computable.
2109  if (!inversePermutation(
2110  concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2111  return rewriter.notifyMatchFailure(
2112  genericOp, "fused op loop bound computation failed");
2113  }
2114 
2115  // Create a constant scalar value from the splat constant.
2116  Value scalarConstant =
2117  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
2118 
2119  SmallVector<Value> outputOperands = genericOp.getOutputs();
2120  auto fusedOp = rewriter.create<GenericOp>(
2121  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2122  /*inputs=*/fusedOperands,
2123  /*outputs=*/outputOperands,
2124  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2125  genericOp.getIteratorTypes(),
2126  /*doc=*/nullptr,
2127  /*library_call=*/nullptr);
2128 
2129  // Map the block argument corresponding to the replaced argument with the
2130  // scalar constant.
2131  Region &region = genericOp->getRegion(0);
2132  Block &entryBlock = *region.begin();
2133  IRMapping mapping;
2134  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2135  scalarConstant);
2136  Region &fusedRegion = fusedOp->getRegion(0);
2137  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2138  mapping);
2139  rewriter.replaceOp(genericOp, fusedOp->getResults());
2140  return success();
2141  }
2142  return failure();
2143  }
2144 };
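// Illustrative effect of this pattern (sketch):
//
// ```mlir
// %cst = arith.constant dense<1.0> : tensor<4x8xf32>
// %r = linalg.generic ... ins(%cst, %b : tensor<4x8xf32>, tensor<4x8xf32>) ...
// ```
//
// becomes a generic op with %cst dropped from the inputs; the corresponding
// payload argument is replaced by the scalar `arith.constant 1.0 : f32`.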
2145 
2146 } // namespace
2147 
2148 //===---------------------------------------------------------------------===//
2149 // Miscellaneous patterns that help fusion.
2150 //===---------------------------------------------------------------------===//
2151 
2152 namespace {
2153 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2154 /// value of the `outs` operand is not used within the op. This is only
2155 /// implemented for `linalg.generic` operations for now, but should hold for all
2156 /// linalg structured ops.
2157 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2158  using OpRewritePattern<GenericOp>::OpRewritePattern;
2159 
2160  LogicalResult matchAndRewrite(GenericOp op,
2161  PatternRewriter &rewriter) const override {
2162  rewriter.startOpModification(op);
2163  bool modifiedOutput = false;
2164  Location loc = op.getLoc();
2165  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2166  if (!op.payloadUsesValueFromOperand(&opOperand)) {
2167  Value operandVal = opOperand.get();
2168  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2169  if (!operandType)
2170  continue;
2171 
2172  // If outs is sparse, leave it to the sparsifier.
2173  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2174  continue;
2175 
2176  // If outs is already an `empty` operation, nothing to do.
2177  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2178  if (definingOp)
2179  continue;
2180  modifiedOutput = true;
2181  SmallVector<OpFoldResult> mixedSizes =
2182  tensor::getMixedSizes(rewriter, loc, operandVal);
2183  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
2184  loc, mixedSizes, operandType.getElementType());
2185  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2186  }
2187  }
2188  if (!modifiedOutput) {
2189  rewriter.cancelOpModification(op);
2190  return failure();
2191  }
2192  rewriter.finalizeOpModification(op);
2193  return success();
2194  }
2195 };
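// Illustrative effect of this pattern (sketch; names hypothetical): when the
// payload never reads the `outs` block argument,
//
//   ... outs(%expensiveInit : tensor<?xf32>)
//
// is rewritten to
//
//   ... outs(%empty : tensor<?xf32>)   // %empty created by tensor.empty
//
// which removes a false dependence on the producer of %expensiveInit.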
2196 
2197 /// Fold linalg.fill into linalg.generic
2198 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2199  using OpRewritePattern<GenericOp>::OpRewritePattern;
2200 
2201  LogicalResult matchAndRewrite(GenericOp genericOp,
2202  PatternRewriter &rewriter) const override {
2203  if (!genericOp.hasPureTensorSemantics())
2204  return failure();
2205  bool fillFound = false;
2206  Block &payload = genericOp.getRegion().front();
2207  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2208  if (!genericOp.payloadUsesValueFromOperand(opOperand))
2209  continue;
2210  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2211  if (!fillOp)
2212  continue;
2213  fillFound = true;
2214  Value fillVal = fillOp.value();
2215  auto resultType =
2216  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2217  Value convertedVal =
2218  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2219  /*isUnsignedCast =*/false);
2220  rewriter.replaceAllUsesWith(
2221  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2222  }
2223  return success(fillFound);
2224  }
2225 };
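// Illustrative effect of this pattern (sketch):
//
// ```mlir
// %f = linalg.fill ins(%cst : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
// %r = linalg.generic ... ins(%f : tensor<?xf32>) ...
// ```
//
// The payload argument corresponding to %f is replaced by %cst (converted to
// the element type if needed); the fill becomes dead if it has no other uses.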
2226 } // namespace
2227 
2228 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2229  RewritePatternSet &patterns,
2230  const ControlFusionFn &controlFoldingReshapes) {
2231  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2232  controlFoldingReshapes);
2233  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2234  controlFoldingReshapes);
2235  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2236  controlFoldingReshapes);
2237 }
2238 
2239 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2240  RewritePatternSet &patterns,
2241  const ControlFusionFn &controlFoldingReshapes) {
2242  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2243  controlFoldingReshapes);
2244  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2245  patterns.getContext(), controlFoldingReshapes);
2246  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2247  controlFoldingReshapes);
2248 }
2249 
2250 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2251  RewritePatternSet &patterns,
2252  const ControlFusionFn &controlElementwiseOpsFusion) {
2253  auto *context = patterns.getContext();
2254  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2255  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2256  RemoveOutsDependency>(context);
2257  // Add the patterns that clean up dead operands and results.
2258  populateEraseUnusedOperandsAndResultsPatterns(patterns);
2259 }
2260 
2261 void mlir::linalg::populateCollapseDimensions(
2262  RewritePatternSet &patterns,
2263  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2264  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2265  CollapseLinalgDimensions<linalg::CopyOp>>(
2266  patterns.getContext(), controlCollapseDimensions);
2267 }
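// Illustrative use of these populate entry points (a minimal sketch; the
// control functions shown accept every candidate):
//
//   RewritePatternSet patterns(context);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   linalg::populateFoldReshapeOpsByCollapsingPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   (void)applyPatternsGreedily(op, std::move(patterns));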
2268 
2269 //===---------------------------------------------------------------------===//
2270 // Passes
2271 //===---------------------------------------------------------------------===//
2272 
2273 namespace {
2274 
2275 /// Pass that fuses generic ops on tensors. Used only for testing.
2276 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2277 // patterns added here heavily depends on the cost function used. Having an
2278 // opinionated pass of this form is not recommended. Deprecate this pass in
2279 // favor of test passes that check the functionality of each of the patterns
2280 // added here individually.
2281 struct LinalgElementwiseOpFusionPass
2282  : public impl::LinalgElementwiseOpFusionPassBase<
2283  LinalgElementwiseOpFusionPass> {
2284  using impl::LinalgElementwiseOpFusionPassBase<
2285  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2286  void runOnOperation() override {
2287  Operation *op = getOperation();
2288  MLIRContext *context = op->getContext();
2289  RewritePatternSet patterns(context);
2290 
2291  // Add folding with reshape by expansion patterns.
2292  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2293  Operation *producer = fusedOperand->get().getDefiningOp();
2294  return producer && producer->hasOneUse();
2295  };
2296 
2297  // Add elementwise op fusion patterns.
2298  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2299  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2300  tensor::populateBubbleUpExpandShapePatterns(patterns);
2301 
2302  // General canonicalization patterns.
2303  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2304  GenericOp::getCanonicalizationPatterns(patterns, context);
2305  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2306  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2307  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2308  patterns);
2309 
2310  // Add constant folding patterns.
2311  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2312 
2313  // Use TopDownTraversal for compile time reasons.
2314  (void)applyPatternsGreedily(op, std::move(patterns),
2315  GreedyRewriteConfig().setUseTopDownTraversal());
2316  }
2317 };
2318 
2319 } // namespace