//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
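
// As an illustrative example (the maps here are hypothetical, not from any
// particular test): let the producer have loops (d0, d1) with
//   producerResultIndexMap = (d0, d1) -> (d1, d0)
// and an input with argMap = (d0, d1) -> (d0, d1). Then
//   invProducerResultIndexMap = (t0, t1) -> (t1, t0)
//   t1 = argMap o invProducerResultIndexMap = (t0, t1) -> (t1, t0)
// With fusedConsumerArgIndexMap = (c0, c1, c2) -> (c0, c2), the returned
// map is (c0, c1, c2) -> (c2, c0), i.e. fused loop -> producer arg index.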

// Checks if the given operand can be dropped, and the remaining operands
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that
  // defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
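
// For illustration, a minimal sketch (hypothetical IR, not taken from the
// test suite) of a producer/consumer pair that satisfies all of the
// conditions above:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %0 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel", "parallel"]}
//     ins(%a : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     %e = arith.addf %in, %in : f32
//     linalg.yield %e : f32
// } -> tensor<?x?xf32>
// %1 = linalg.generic {
//     indexing_maps = [#map, #map, #map],
//     iterator_types = ["parallel", "parallel"]}
//     ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//     outs(%init2 : tensor<?x?xf32>) {.. } -> tensor<?x?xf32>
// ```
//
// Both ops are all-parallel generics on tensors, %0 feeds an input operand
// of %1, and the producer's result indexing map is a permutation.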

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside the producer block.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yield values).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
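
// To summarize the numbered steps above: for a hypothetical consumer with
// inputs (in0, fusedOperand, in1) and one init, fused with a producer with
// inputs (p0, p1) and one preserved init, the fused block arguments end up
// ordered as
//   (in0, p0, p1, in1, producer-init, consumer-init)
// and the consumer's block argument for `fusedOperand` is not materialized:
// it is replaced by the producer's cloned yield value (step 9).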

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
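
// Continuing the hypothetical example sketched after
// `areElementwiseOpsFusable` above, and assuming the consumer multiplied its
// two inputs, the fusion would produce a single generic op along these lines:
//
// ```mlir
// %fused = linalg.generic {
//     indexing_maps = [#map, #map, #map],
//     iterator_types = ["parallel", "parallel"]}
//     ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//     outs(%init2 : tensor<?x?xf32>) {
//   ^bb0(%a_elem: f32, %b_elem: f32, %out: f32):
//     %e = arith.addf %a_elem, %a_elem : f32
//     %r = arith.mulf %e, %b_elem : f32
//     linalg.yield %r : f32
// } -> tensor<?x?xf32>
// ```
//
// The producer body is inlined and its yielded value replaces the block
// argument that corresponded to the fused operand.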

namespace {
/// Pattern to fuse a generic op with the producers of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace the original values with the results of the fused op.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since operands to the linalg generic are now 6D and 4D, reshapes can be
/// introduced to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from the original dimensions of the op to the
  // dimensions of the expanded op given the `indexingMap` of the fused
  // operand/result of the generic op, the `reassociationMaps` of the reshape
  // op and the shape of the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
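
// A small worked example (with hypothetical values): suppose the original op
// has two loops (d0, d1), the fusable operand has indexing map
// (d0, d1) -> (d0, d1), and the reshape reassociation is [[0, 1], [2]] with
// expandedShape = [%s0, 4, %s1]. Then this computes:
//   numExpandedDims   = [2, 1]
//   expandedShapeMap  = [[%s0, 4], [%s1]]
//   reassociation     = [[0, 1], [2]]
//   expandedOpNumDims = 3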

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
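
// For instance (hypothetical maps): with indexingMap = (d0, d1) -> (d1, d0)
// and an expansion where d0 -> {e0, e1} and d1 -> {e2}, the expanded map is
//   (e0, e1, e2) -> (e2, e0, e1)
// i.e. each original result dimension is replaced in place by the list of
// dimensions it was expanded into.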

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
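
// Continuing the hypothetical example above: for an operand with
// indexingMap = (d0, d1) -> (d1, d0), where d1 expands to one dimension and
// d0 expands to two, the returned reassociation is [[0], [1, 2]]; result
// position 0 (d1) stays a single dimension while result position 1 (d0)
// covers two expanded dimensions.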

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    OpFoldResult newIndex =
        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}
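
// Concretely, if an original dimension was expanded into (e0, e1, e2) with
// extents (s0, s1, s2), the loop above rebuilds the original index as
//   ((idx(e0) * s1) + idx(e1)) * s2 + idx(e2)
// one `idx + acc * shape` affine.apply at a time.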

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we invert-permute and then
// flatten it. Then we invert-permute it again to get the final expanded
// transpose permutation. For example,
//
// permutation = [2, 0, 1]
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
// inverse permutation = [1, 2, 0]
// applied to the reassociation_map and then flattened becomes
// flattened permutation = [2, 3, 4, 5, 0, 1]
// final permutation is the inverse of the flattened permutation.
//
// Becomes
//
// permutation = [4, 5, 0, 1, 2, 3]
static TransposeOp createExpandedTransposeOp(PatternRewriter &rewriter,
                                             TransposeOp transposeOp,
                                             Value expandedInput, Value output,
                                             ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
                                      output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // Start with all-parallel iterator types, then copy the original iterator
  // types over the corresponding expanded dimensions.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = rewriter.create<GenericOp>(
      linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
      expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in the output shape before the
    // `linalgOp` to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
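
// A minimal sketch of what the pad pattern above does (illustrative IR):
//
// ```mlir
// %c = tensor.collapse_shape %s [[0, 1], [2]]
//     : tensor<?x4x?xf32> into tensor<?x?xf32>
// %p = tensor.pad %c low[0, 2] high[0, 3] { ... }
//     : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// Only the unit reassociation group [2] is padded, so this is rewritten to
// pad the expanded source and collapse the padded result:
//
// ```mlir
// %p = tensor.pad %s low[0, 0, 2] high[0, 0, 3] { ... }
//     : tensor<?x4x?xf32> to tensor<?x4x?xf32>
// %r = tensor.collapse_shape %p [[0, 1], [2]]
//     : tensor<?x4x?xf32> into tensor<?x?xf32>
// ```
//
// Collapsed groups (here [0, 1]) must have zero padding, as checked above.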

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have
    // the same type as the results of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
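
// For example (hypothetical map): with
//   indexingMap = (d0, d1, d2) -> (d2, d0, d1)
// and rangeReassociation = [1, 2] (results 1 and 2 of the map are folded),
// the returned domain reassociation is [0, 1], i.e. the loop dimensions
// appearing at those result positions.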

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}
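
// A few illustrative cases (hypothetical maps):
//   map = (d0, d1, d2) -> (d0, d1, d2), sequence [1, 2] -> true
//     (d1 and d2 appear contiguously and in order).
//   map = (d0, d1, d2) -> (d0, d2, d1), sequence [0, 1] -> false
//     (d0 is followed by d2 rather than d1).
//   map = (d0, d1, d2) -> (d2),         sequence [0, 1] -> true
//     (no element of the sequence appears in the map's results).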

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndicesRef> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
1298  // Some basic checks for this fusion to be valid.
1299  if (!genericOp.hasPureTensorSemantics())
1300  return {};
1301 
1302  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1303  return map.isProjectedPermutation();
1304  })) {
1305  return {};
1306  }
1307 
1308  // Compute all the loops with the reduction iterator types.
1309  SmallVector<unsigned> reductionDims;
1310  genericOp.getReductionDims(reductionDims);
1311 
1312  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1313  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1314  auto iteratorTypes = genericOp.getIteratorTypesArray();
1315  SmallVector<ReassociationIndices> iterationSpaceReassociation;
1316  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1317  assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1318 
1319  // Ignore dims that are not folded.
1320  if (foldedRangeDims.size() == 1)
1321  continue;
1322 
1323  ReassociationIndices foldedIterationSpaceDims =
1324  getDomainReassociation(indexingMap, foldedRangeDims);
1325 
1326  // Check that the folded iteration dims do not contain already processed
1327  // dims.
1328  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1329  return processedIterationDims.count(dim);
1330  }))
1331  continue;
1332 
1333  // Check that all folded iterator types are all parallel or all reductions.
1334  utils::IteratorType startIteratorType =
1335  iteratorTypes[foldedIterationSpaceDims[0]];
1336  if (!isParallelIterator(startIteratorType) &&
1337  !isReductionIterator(startIteratorType))
1338  continue;
1339  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1340  return iteratorTypes[dim] != startIteratorType;
1341  }))
1342  continue;
1343 
1344  // If the folded dimensions correspond to a "reduction" iterator type,
1345  // the folded dimensions need to be "in-order". Strictly speaking this is
1346  // not necessary, for reductions that are associative and commutative, but
1347  // using a more strict definition of reduction for now.
1348  if (isReductionIterator(startIteratorType)) {
1349  bool isContiguous = false;
1350  for (const auto &startDim : llvm::enumerate(reductionDims)) {
1351  // Move window in `reductionDims` to start of the folded iteration dims.
1352  if (startDim.value() != foldedIterationSpaceDims[0])
1353  continue;
1354  // If sizes doesnt match, trivial not contiguous. This condition should
1355  // not be hit.
1356  if (startDim.index() + foldedIterationSpaceDims.size() >
1357  reductionDims.size())
1358  break;
1359  // Check that the contiguity is maintained.
1360  isContiguous = true;
1361  for (const auto &foldedDim :
1362  llvm::enumerate(foldedIterationSpaceDims)) {
1363  if (reductionDims[foldedDim.index() + startDim.index()] !=
1364  foldedDim.value()) {
1365  isContiguous = false;
1366  break;
1367  }
1368  }
1369  break;
1370  }
1371  if (!isContiguous)
1372  continue;
1373  }
1374 
1375  // Check that the sequence is preserved in all indexing maps.
1376  if (llvm::any_of(genericOp.getIndexingMapsArray(),
1377  [&](AffineMap indexingMap) {
1378  return !isDimSequencePreserved(indexingMap,
1379  foldedIterationSpaceDims);
1380  }))
1381  continue;
1382 
1383  processedIterationDims.insert_range(foldedIterationSpaceDims);
1384  iterationSpaceReassociation.emplace_back(
1385  std::move(foldedIterationSpaceDims));
1386  }
1387 
1388  return iterationSpaceReassociation;
1389 }
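For intuition, a minimal standalone sketch (plain C++, independent of MLIR; `dimSequencePreserved` is a hypothetical, simplified stand-in for `isDimSequencePreserved` that ignores projected-out dims):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified check: does `seq` occur as a contiguous, in-order run in the
// result dims of an indexing map (represented as a vector of dim positions)?
static bool dimSequencePreserved(const std::vector<int64_t> &resultDims,
                                 const std::vector<int64_t> &seq) {
  return std::search(resultDims.begin(), resultDims.end(), seq.begin(),
                     seq.end()) != resultDims.end();
}

int main() {
  // For #map1 = (d0, d1) -> (d1, d0) in the example above, folding {d0, d1}
  // is not preserved, so no iteration dims can be collapsed.
  assert(!dimSequencePreserved({1, 0}, {0, 1}));
  // For the identity map the sequence is preserved.
  assert(dimSequencePreserved({0, 1}, {0, 1}));
  return 0;
}
```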
1390 
1391 /// Helper class to carry state while collapsing the `linalg.generic` op.
1392 namespace {
1393 class CollapsingInfo {
1394 public:
1395  LogicalResult initialize(unsigned origNumLoops,
1396  ArrayRef<ReassociationIndices> foldedIterationDims) {
1397  llvm::SmallDenseSet<int64_t, 4> processedDims;
1398  // Find all the dims that are folded.
1399  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1400  if (foldedIterationDim.empty())
1401  continue;
1402 // If the folded dims contain dims that were already folded, that is an
1403 // illegal specification. Repetition within a list is also illegal.
1404  for (auto dim : foldedIterationDim) {
1405  if (dim >= origNumLoops)
1406  return failure();
1407  if (processedDims.count(dim))
1408  return failure();
1409  processedDims.insert(dim);
1410  }
1411  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1412  foldedIterationDim.end());
1413  }
1414  if (processedDims.size() > origNumLoops)
1415  return failure();
1416 
1417  // Add all the preserved dims of the original op as single
1418  // elements to `collapsedOpToOrigOpIterationDim`.
1419  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1420  if (processedDims.count(dim))
1421  continue;
1422  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1423  }
1424 
1425  llvm::sort(collapsedOpToOrigOpIterationDim,
1426  [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1427  return lhs[0] < rhs[0];
1428  });
1429  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1430  for (const auto &foldedDims :
1431  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1432  for (const auto &dim : enumerate(foldedDims.value()))
1433  origOpToCollapsedOpIterationDim[dim.value()] =
1434  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1435  }
1436  return success();
1437  }
1438 
1439  /// Return mapping from collapsed loop domain to original loop domain.
1440  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1441  return collapsedOpToOrigOpIterationDim;
1442  }
1443 
1444 /// Return mapping from original loop domain to collapsed loop domain. The
1445 /// mapping is a pair: the first value is the dimension in the collapsed loop
1446 /// that the original loop maps to; the second is the relative position of the
1447 /// original loop within that folded list. For example, if the original loop
1448 /// domain is 5D and the collapsed loop domain folds all of it, i.e.
1449  ///
1450  /// ```
1451 /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1452  /// ```
1453  ///
1454  /// then
1455  ///
1456  /// ```
1457  /// origOpToCollapsedOpMapping[0] = {0, 0};
1458  /// origOpToCollapsedOpMapping[1] = {0, 1};
1459  /// origOpToCollapsedOpMapping[2] = {0, 2};
1460  /// origOpToCollapsedOpMapping[3] = {1, 0};
1461  /// origOpToCollapsedOpMapping[4] = {1, 1};
1462  /// ```
1463  ///
1464  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1465  return origOpToCollapsedOpIterationDim;
1466  }
1467 
1468  /// Return the collapsed op iteration domain rank.
1469  unsigned getCollapsedOpIterationRank() const {
1470  return collapsedOpToOrigOpIterationDim.size();
1471  }
1472 
1473 private:
1474  /// Map from the iteration domain index in collapsed op to the iteration
1475  /// domain indices in the original op.
1476  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1477 
1478  /// Map from iteration domain index in the original op to the iteration domain
1479  /// index in the collapsed op.
1480  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1481 };
1482 } // namespace
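A standalone sketch (plain C++, hypothetical names) of the two mappings `initialize` computes, using the 5-loop example documented above:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  // collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]].
  std::vector<std::vector<int64_t>> collapsedToOrig = {{0, 1, 2}, {3, 4}};
  // origToCollapsed[origDim] = {collapsed dim, position within its fold}.
  std::vector<std::pair<int64_t, unsigned>> origToCollapsed(5);
  for (unsigned group = 0; group < collapsedToOrig.size(); ++group)
    for (unsigned pos = 0; pos < collapsedToOrig[group].size(); ++pos)
      origToCollapsed[collapsedToOrig[group][pos]] = {group, pos};
  // Matches the documented example: dim 3 is the first dim of fold 1, etc.
  assert((origToCollapsed[3] == std::pair<int64_t, unsigned>(1, 0)));
  assert((origToCollapsed[4] == std::pair<int64_t, unsigned>(1, 1)));
  return 0;
}
```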
1483 
1484 /// Get the iterator types for the collapsed operation given the original
1485 /// iterator types and collapsed dimensions.
1486 static SmallVector<utils::IteratorType>
1487 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1488  const CollapsingInfo &collapsingInfo) {
1489  SmallVector<utils::IteratorType> collapsedIteratorTypes;
1490  for (ReassociationIndicesRef foldedIterDims :
1491  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1492  assert(!foldedIterDims.empty() &&
1493  "reassociation indices expected to have non-empty sets");
1494 // Just pick the iterator type of the first folded dim. Pre-condition
1495 // checks are expected to have verified that the iterator types of all
1496 // folded dimensions are the same.
1497  collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1498  }
1499  return collapsedIteratorTypes;
1500 }
1501 
1502 /// Compute the indexing map in the collapsed op that corresponds to the given
1503 /// `indexingMap` of the original operation.
1504 static AffineMap
1505 getCollapsedOpIndexingMap(AffineMap indexingMap,
1506  const CollapsingInfo &collapsingInfo) {
1507  MLIRContext *context = indexingMap.getContext();
1508  assert(indexingMap.isProjectedPermutation() &&
1509  "expected indexing map to be projected permutation");
1510  SmallVector<AffineExpr> resultExprs;
1511  auto origOpToCollapsedOpMapping =
1512  collapsingInfo.getOrigOpToCollapsedOpMapping();
1513  for (auto expr : indexingMap.getResults()) {
1514  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1515  // If the dim is not the first of the collapsed dim, do nothing.
1516  if (origOpToCollapsedOpMapping[dim].second != 0)
1517  continue;
1518  // The next n-dims are guaranteed to be collapsed. So just use the
1519  // iteration dimension of the collapsed op.
1520  resultExprs.push_back(
1521  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1522  }
1523  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1524  resultExprs, context);
1525 }
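A worked standalone sketch (plain C++) of the filtering above: only the leading dim of each folded group survives, renumbered to the collapsed dims.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  // Original map results (d0, d1, d2, d3, d4) with folds [[0, 1, 2], [3, 4]].
  std::vector<int64_t> resultDims = {0, 1, 2, 3, 4};
  std::vector<std::pair<int64_t, unsigned>> origToCollapsed = {
      {0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}};
  std::vector<int64_t> collapsedResults;
  for (int64_t dim : resultDims)
    if (origToCollapsed[dim].second == 0) // keep only the first of each fold
      collapsedResults.push_back(origToCollapsed[dim].first);
  // The collapsed map is (d0, d1) -> (d0, d1).
  assert((collapsedResults == std::vector<int64_t>{0, 1}));
  return 0;
}
```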
1526 
1527 /// Return the `reassociation` indices to use to collapse the operand when the
1528 /// iteration space of a generic op is collapsed.
1529 static SmallVector<ReassociationIndices>
1530 getOperandReassociation(AffineMap indexingMap,
1531  const CollapsingInfo &collapsingInfo) {
1532  unsigned counter = 0;
1533  SmallVector<ReassociationIndices> operandReassociation;
1534  auto origOpToCollapsedOpMapping =
1535  collapsingInfo.getOrigOpToCollapsedOpMapping();
1536  auto collapsedOpToOrigOpMapping =
1537  collapsingInfo.getCollapsedOpToOrigOpMapping();
1538  while (counter < indexingMap.getNumResults()) {
1539  unsigned dim =
1540  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1541 // This is the start of a collapsed dimension group of the iteration space
1542 // that is guaranteed to be preserved in the indexing map. The number of folded
1543  // dims is obtained from the collapsed op to original op mapping.
1544  unsigned numFoldedDims =
1545  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1546  .size();
1547  if (origOpToCollapsedOpMapping[dim].second == 0) {
1548  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1549  operandReassociation.emplace_back(range.begin(), range.end());
1550  }
1551  counter += numFoldedDims;
1552  }
1553  return operandReassociation;
1554 }
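The companion computation as a standalone sketch (plain C++; it assumes every dim encountered starts its fold, which the preconditions above guarantee):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  // Operand results (d3, d4, d0, d1, d2); iteration folds [[0, 1, 2], [3, 4]].
  std::vector<int64_t> resultDims = {3, 4, 0, 1, 2};
  std::vector<int64_t> origToCollapsed = {0, 0, 0, 1, 1}; // orig dim -> fold
  std::vector<unsigned> foldSize = {3, 2};                // dims per fold
  std::vector<std::vector<unsigned>> reassociation;
  unsigned counter = 0;
  while (counter < resultDims.size()) {
    unsigned n = foldSize[origToCollapsed[resultDims[counter]]];
    std::vector<unsigned> group;
    for (unsigned i = 0; i < n; ++i)
      group.push_back(counter + i);
    reassociation.push_back(std::move(group));
    counter += n;
  }
  // The operand collapses with reassociation [[0, 1], [2, 3, 4]].
  assert(reassociation.size() == 2 && reassociation[0].size() == 2 &&
         reassociation[1].size() == 3);
  return 0;
}
```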
1555 
1556 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1557 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1558  OpOperand *opOperand,
1559  const CollapsingInfo &collapsingInfo,
1560  OpBuilder &builder) {
1561  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1562  SmallVector<ReassociationIndices> operandReassociation =
1563  getOperandReassociation(indexingMap, collapsingInfo);
1564 
1565 // If the number of entries in the reassociation for the operand is the
1566 // same as the number of results of the indexing map, then there is nothing
1567 // to do for this operand.
1568  Value operand = opOperand->get();
1569  if (operandReassociation.size() == indexingMap.getNumResults())
1570  return operand;
1571 
1572  // Insert a reshape to collapse the dimensions.
1573  if (isa<MemRefType>(operand.getType())) {
1574  return builder
1575  .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1576  .getResult();
1577  }
1578  return builder
1579  .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1580  .getResult();
1581 }
1582 
1583 /// Modify the `linalg.index` operations in the original generic op, to its
1584 /// value in the collapsed operation.
1585 static void generateCollapsedIndexingRegion(
1586  Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1587  ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1588  OpBuilder::InsertionGuard g(rewriter);
1589  rewriter.setInsertionPointToStart(block);
1590 
1591  // Collect all the original index ops.
1592  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1593 
1594 // For each folded dimension list resolve the original induction variable
1595 // values in terms of the folded dimension induction variable.
1596 //   i_{folded} = (i_0 * d1 + i_1) * d2 + i_2
1597 // can be inverted to
1598 //   i_2 = i_{folded} % d2
1599 //   i_1 = (i_{folded} / d2) % d1
1600 //   i_0 = i_{folded} / (d1 * d2)
1601  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1602  for (auto foldedDims :
1603  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1604  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1605  Value newIndexVal =
1606  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1607  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1608  Value loopDim =
1609  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1610  indexReplacementVals[dim] =
1611  rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1612  newIndexVal =
1613  rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1614  }
1615  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1616  }
1617 
1618  for (auto indexOp : indexOps) {
1619  auto dim = indexOp.getDim();
1620  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1621  }
1622 }
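The inversion described in the comments above, as a runnable standalone sketch (plain C++):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Invert i_folded = (i_0 * d1 + i_1) * d2 + i_2.
static std::array<int64_t, 3> delinearize(int64_t folded, int64_t d1,
                                          int64_t d2) {
  int64_t i2 = folded % d2;
  int64_t i1 = (folded / d2) % d1;
  int64_t i0 = folded / (d1 * d2);
  return {i0, i1, i2};
}

int main() {
  // With d1 = 3, d2 = 5: (i0, i1, i2) = (1, 2, 4) linearizes to
  // (1 * 3 + 2) * 5 + 4 = 29, and delinearize recovers it.
  auto [i0, i1, i2] = delinearize(29, 3, 5);
  assert(i0 == 1 && i1 == 2 && i2 == 4);
  return 0;
}
```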
1623 
1624 void collapseOperandsAndResults(LinalgOp op,
1625  const CollapsingInfo &collapsingInfo,
1626  RewriterBase &rewriter,
1627  SmallVectorImpl<Value> &inputOperands,
1628  SmallVectorImpl<Value> &outputOperands,
1629  SmallVectorImpl<Type> &resultTypes) {
1630  Location loc = op->getLoc();
1631  inputOperands =
1632  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1633  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1634  rewriter);
1635  });
1636 
1637  // Get the output operands and result types.
1638  resultTypes.reserve(op.getNumDpsInits());
1639  outputOperands.reserve(op.getNumDpsInits());
1640  for (OpOperand &output : op.getDpsInitsMutable()) {
1641  Value newOutput =
1642  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1643  outputOperands.push_back(newOutput);
1644  // If the op has "buffer semantics", then the init operands are ranked
1645  // memrefs and the op has no results.
1646  if (!op.hasPureBufferSemantics())
1647  resultTypes.push_back(newOutput.getType());
1648  }
1649 }
1650 
1651 /// Clone a `LinalgOp` to a collapsed version of the same name.
1652 template <typename OpTy>
1653 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1654  const CollapsingInfo &collapsingInfo) {
1655  return nullptr;
1656 }
1657 
1658 /// Collapse any `LinalgOp` that does not require any specialization such as
1659 /// indexing_maps, iterator_types, etc.
1660 template <>
1661 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1662  const CollapsingInfo &collapsingInfo) {
1663  SmallVector<Value> inputOperands, outputOperands;
1664  SmallVector<Type> resultTypes;
1665  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1666  outputOperands, resultTypes);
1667 
1668  return clone(
1669  rewriter, origOp, resultTypes,
1670  llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1671 }
1672 
1673 /// Collapse a `GenericOp`
1674 template <>
1675 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1676  GenericOp origOp,
1677  const CollapsingInfo &collapsingInfo) {
1678  SmallVector<Value> inputOperands, outputOperands;
1679  SmallVector<Type> resultTypes;
1680  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1681  outputOperands, resultTypes);
1682  SmallVector<AffineMap> indexingMaps(
1683  llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1684  return getCollapsedOpIndexingMap(map, collapsingInfo);
1685  }));
1686 
1687  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1688  origOp.getIteratorTypesArray(), collapsingInfo));
1689 
1690  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1691  origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1692  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1693  Block *origOpBlock = &origOp->getRegion(0).front();
1694  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1695  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1696  collapsedOpBlock->getArguments());
1697  return collapsedOp;
1698 }
1699 
1700 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1701  RewriterBase &rewriter) {
1702  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1703  return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1704  } else {
1705  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1706  }
1707 }
1708 
1709 /// Implementation of fusion with reshape operation by collapsing dimensions.
1710 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1711  LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1712  RewriterBase &rewriter) {
1713  // Bail on trivial no-op cases.
1714  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1715  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1716  return foldedDims.size() <= 1;
1717  }))
1718  return failure();
1719 
1720  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1721  if (hasPureBufferSemantics &&
1722  !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1723  MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1724  if (!memRefToCollapse)
1725  return true;
1726 
1727  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1728  memRefToCollapse, foldedIterationDims);
1729  }))
1730  return rewriter.notifyMatchFailure(op,
1731  "memref is not guaranteed collapsible");
1732 
1733  CollapsingInfo collapsingInfo;
1734  if (failed(
1735  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1736  return rewriter.notifyMatchFailure(
1737  op, "illegal to collapse specified dimensions");
1738  }
1739 
1740  // Bail on non-canonical ranges.
1741  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1742  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1743  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1744  return cast<IntegerAttr>(attr).getInt() == value;
1745  llvm::APInt actual;
1746  return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1747  actual.getSExtValue() == value;
1748  };
1749  if (!llvm::all_of(loopRanges, [&](Range range) {
1750  return opFoldIsConstantValue(range.offset, 0) &&
1751  opFoldIsConstantValue(range.stride, 1);
1752  })) {
1753  return rewriter.notifyMatchFailure(
1754  op, "expected all loop ranges to have zero start and unit stride");
1755  }
1756 
1757  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1758 
1759  Location loc = op->getLoc();
1760  SmallVector<OpFoldResult> loopBound =
1761  llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1762 
1763  if (collapsedOp.hasIndexSemantics()) {
1764  // Collect the loop range of the generic op.
1765  OpBuilder::InsertionGuard g(rewriter);
1766  rewriter.setInsertionPoint(collapsedOp);
1767  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1768  collapsingInfo, loopBound, rewriter);
1769  }
1770 
1771  // Insert expanding reshape for the result to get back the original result
1772  // type.
1773  SmallVector<Value> results;
1774  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1775  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1776  auto originalResultType =
1777  cast<ShapedType>(originalResult.value().getType());
1778  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1779  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1780  AffineMap indexingMap =
1781  op.getIndexingMapMatchingResult(originalResult.value());
1782  SmallVector<ReassociationIndices> reassociation =
1783  getOperandReassociation(indexingMap, collapsingInfo);
1784  assert(
1785  indexingMap.isProjectedPermutation() &&
1786  "Expected indexing map to be a projected permutation for collapsing");
1787  SmallVector<OpFoldResult> resultShape =
1788  applyPermutationMap(indexingMap, ArrayRef(loopBound));
1789  Value result;
1790  if (isa<MemRefType>(collapsedOpResult.getType())) {
1791  MemRefType expandShapeResultType = MemRefType::get(
1792  originalResultType.getShape(), originalResultType.getElementType());
1793  result = rewriter.create<memref::ExpandShapeOp>(
1794  loc, expandShapeResultType, collapsedOpResult, reassociation,
1795  resultShape);
1796  } else {
1797  result = rewriter.create<tensor::ExpandShapeOp>(
1798  loc, originalResultType, collapsedOpResult, reassociation,
1799  resultShape);
1800  }
1801  results.push_back(result);
1802  } else {
1803  results.push_back(collapsedOpResult);
1804  }
1805  }
1806  return CollapseResult{results, collapsedOp};
1807 }
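A hedged usage sketch (not from this file; `rewriter` and `genericOp` are assumed to be in scope and to satisfy the preconditions above):

```cpp
// Hedged usage sketch: `rewriter` is a RewriterBase and `genericOp` a
// LinalgOp whose loops 0 and 1 are collapsible.
SmallVector<ReassociationIndices> foldedDims = {{0, 1}}; // fold loops 0 and 1
FailureOr<CollapseResult> collapsed =
    linalg::collapseOpIterationDims(genericOp, foldedDims, rewriter);
if (succeeded(collapsed))
  rewriter.replaceOp(genericOp, collapsed->results);
```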
1808 
1809 namespace {
1810 
1811 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1812 /// contracting dimensions of the loop.
1813 class FoldWithProducerReshapeOpByCollapsing
1814  : public OpRewritePattern<GenericOp> {
1815 public:
1816  // TODO : support fusion with all linalg ops, not just generic.
1817  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1818  ControlFusionFn foldReshapes,
1819  PatternBenefit benefit = 1)
1820  : OpRewritePattern<GenericOp>(context, benefit),
1821  controlFoldingReshapes(std::move(foldReshapes)) {}
1822 
1823  LogicalResult matchAndRewrite(GenericOp genericOp,
1824  PatternRewriter &rewriter) const override {
1825  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1826  tensor::ExpandShapeOp reshapeOp =
1827  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1828  if (!reshapeOp)
1829  continue;
1830 
1831  SmallVector<ReassociationIndices> collapsableIterationDims =
1832  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1833  reshapeOp.getReassociationIndices());
1834  if (collapsableIterationDims.empty() ||
1835  !controlFoldingReshapes(&opOperand)) {
1836  continue;
1837  }
1838 
1839  std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1840  genericOp, collapsableIterationDims, rewriter);
1841  if (!collapseResult) {
1842  return rewriter.notifyMatchFailure(
1843  genericOp, "failed to do the fusion by collapsing transformation");
1844  }
1845 
1846  rewriter.replaceOp(genericOp, collapseResult->results);
1847  return success();
1848  }
1849  return failure();
1850  }
1851 
1852 private:
1853  ControlFusionFn controlFoldingReshapes;
1854 };
1855 
1856 /// Pattern to fold a tensor.collapse_shape op with its producer generic op
1857 /// by expanding the dimensionality of the loop in the producer op.
1858 struct FoldReshapeWithGenericOpByCollapsing
1859  : public OpRewritePattern<tensor::CollapseShapeOp> {
1860 
1861  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1862  ControlFusionFn foldReshapes,
1863  PatternBenefit benefit = 1)
1864  : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1865  controlFoldingReshapes(std::move(foldReshapes)) {}
1866 
1867  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1868  PatternRewriter &rewriter) const override {
1869  // Fold only if all constraints of fusing with reshape by collapsing are
1870  // met.
1871  auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1872  if (!producerResult) {
1873  return rewriter.notifyMatchFailure(reshapeOp,
1874  "source not produced by an operation");
1875  }
1876 
1877  // TODO : support fusion with all linalg producers, not just generic.
1878  auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1879  if (!producer) {
1880  return rewriter.notifyMatchFailure(reshapeOp,
1881  "producer not a generic op");
1882  }
1883 
1884  SmallVector<ReassociationIndices> collapsableIterationDims =
1885  getCollapsableIterationSpaceDims(
1886  producer,
1887  producer.getDpsInitOperand(producerResult.getResultNumber()),
1888  reshapeOp.getReassociationIndices());
1889  if (collapsableIterationDims.empty()) {
1890  return rewriter.notifyMatchFailure(
1891  reshapeOp, "failed preconditions of fusion with producer generic op");
1892  }
1893 
1894  if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1895  return rewriter.notifyMatchFailure(reshapeOp,
1896  "fusion blocked by control function");
1897  }
1898 
1899  // Set the insertion point after `producer` because there could be uses
1900  // of `producer` between it and the `tensor.collapse_shape` op.
1901  rewriter.setInsertionPointAfter(producer);
1902  std::optional<CollapseResult> collapseResult =
1903  collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
1904  if (!collapseResult) {
1905  return rewriter.notifyMatchFailure(
1906  producer, "failed to do the fusion by collapsing transformation");
1907  }
1908 
1909  rewriter.replaceOp(producer, collapseResult->results);
1910  return success();
1911  }
1912 
1913 private:
1914  ControlFusionFn controlFoldingReshapes;
1915 };
1916 
1917 class FoldPadWithProducerReshapeOpByCollapsing
1918  : public OpRewritePattern<tensor::PadOp> {
1919 public:
1920  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1921  ControlFusionFn foldReshapes,
1922  PatternBenefit benefit = 1)
1923  : OpRewritePattern<tensor::PadOp>(context, benefit),
1924  controlFoldingReshapes(std::move(foldReshapes)) {}
1925 
1926  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1927  PatternRewriter &rewriter) const override {
1928  tensor::ExpandShapeOp reshapeOp =
1929  padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1930  if (!reshapeOp)
1931  return failure();
1932  if (!reshapeOp->hasOneUse())
1933  return failure();
1934 
1935  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1936  return rewriter.notifyMatchFailure(padOp,
1937  "fusion blocked by control function");
1938  }
1939 
1940  ArrayRef<int64_t> low = padOp.getStaticLow();
1941  ArrayRef<int64_t> high = padOp.getStaticHigh();
1942  SmallVector<ReassociationIndices> reassociations =
1943  reshapeOp.getReassociationIndices();
1944 
1945  for (auto reInd : reassociations) {
1946  if (reInd.size() == 1)
1947  continue;
1948  if (llvm::any_of(reInd, [&](int64_t ind) {
1949  return low[ind] != 0 || high[ind] != 0;
1950  })) {
1951  return failure();
1952  }
1953  }
1954 
1955  SmallVector<OpFoldResult> newLow, newHigh;
1956  RankedTensorType collapsedType = reshapeOp.getSrcType();
1957  RankedTensorType paddedType = padOp.getResultType();
1958  SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1959  SmallVector<OpFoldResult> expandedPaddedSizes(
1960  getMixedValues(reshapeOp.getStaticOutputShape(),
1961  reshapeOp.getOutputShape(), rewriter));
1962  AffineExpr d0, d1, d2;
1963  bindDims(rewriter.getContext(), d0, d1, d2);
1964  auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1965  Location loc = reshapeOp->getLoc();
1966  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1967  OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1968  OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1969  if (reInd.size() == 1) {
1970  collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1971  OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
1972  rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1973  expandedPaddedSizes[reInd[0]] = paddedSize;
1974  }
1975  newLow.push_back(l);
1976  newHigh.push_back(h);
1977  }
1978 
1979  RankedTensorType collapsedPaddedType =
1980  paddedType.clone(collapsedPaddedShape);
1981  auto newPadOp = rewriter.create<tensor::PadOp>(
1982  loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1983  padOp.getConstantPaddingValue(), padOp.getNofold());
1984 
1985  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1986  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1987  expandedPaddedSizes);
1988 
1989  return success();
1990  }
1991 
1992 private:
1993  ControlFusionFn controlFoldingReshapes;
1994 };
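The size bookkeeping above, reduced to a standalone arithmetic sketch (plain C++; the values are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Reassociation [[0], [1, 2]]: only dim 0 (not folded) may carry padding.
  int64_t low0 = 2, high0 = 3, srcSize0 = 10;
  // Mirrors makeComposedFoldedAffineApply with addMap = d0 + d1 + d2: the
  // expanded padded size of dim 0 is low + size + high.
  int64_t expandedPaddedSize0 = low0 + high0 + srcSize0;
  assert(expandedPaddedSize0 == 15);
  return 0;
}
```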
1995 
1996 /// Pattern to collapse dimensions.
1997 template <typename LinalgType>
1998 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1999 public:
2000  CollapseLinalgDimensions(MLIRContext *context,
2001  GetCollapsableDimensionsFn collapseDimensions,
2002  PatternBenefit benefit = 1)
2003  : OpRewritePattern<LinalgType>(context, benefit),
2004  controlCollapseDimension(std::move(collapseDimensions)) {}
2005 
2006  LogicalResult matchAndRewrite(LinalgType op,
2007  PatternRewriter &rewriter) const override {
2008  SmallVector<ReassociationIndices> collapsableIterationDims =
2009  controlCollapseDimension(op);
2010  if (collapsableIterationDims.empty())
2011  return failure();
2012 
2013  // Check if the specified list of dimensions to collapse is a valid list.
2014  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2015  collapsableIterationDims)) {
2016  return rewriter.notifyMatchFailure(
2017  op, "specified dimensions cannot be collapsed");
2018  }
2019 
2020  std::optional<CollapseResult> collapseResult =
2021  collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2022  if (!collapseResult) {
2023  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2024  }
2025  rewriter.replaceOp(op, collapseResult->results);
2026  return success();
2027  }
2028 
2029 private:
2030  GetCollapsableDimensionsFn controlCollapseDimension;
2031 };
2032 
2033 } // namespace
2034 
2035 //===---------------------------------------------------------------------===//
2036 // Methods and patterns that fuse constants with linalg.generic operations.
2037 //===---------------------------------------------------------------------===//
2038 
2039 namespace {
2040 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2041 /// handle cases where the constant is not single-valued.
2042 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2043 public:
2044  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2045  : OpRewritePattern<GenericOp>(context, benefit) {}
2046 
2047  LogicalResult matchAndRewrite(GenericOp genericOp,
2048  PatternRewriter &rewriter) const override {
2049  if (!genericOp.hasPureTensorSemantics())
2050  return failure();
2051  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2052  Operation *def = opOperand->get().getDefiningOp();
2053  TypedAttr constantAttr;
2054  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2055  {
2056  DenseElementsAttr splatAttr;
2057  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2058  splatAttr.isSplat() &&
2059  splatAttr.getType().getElementType().isIntOrFloat()) {
2060  constantAttr = splatAttr.getSplatValue<TypedAttr>();
2061  return true;
2062  }
2063  }
2064  {
2065  IntegerAttr intAttr;
2066  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2067  constantAttr = intAttr;
2068  return true;
2069  }
2070  }
2071  {
2072  FloatAttr floatAttr;
2073  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2074  constantAttr = floatAttr;
2075  return true;
2076  }
2077  }
2078  return false;
2079  };
2080 
2081  auto resultValue = dyn_cast<OpResult>(opOperand->get());
2082  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2083  continue;
2084 
2085 // The operands and the indexing_maps of the fused operation are the same
2086 // as the operands and indexing_maps of the generic operation, with the
2087 // values at the constant index dropped.
2088  SmallVector<AffineMap> fusedIndexMaps;
2089  SmallVector<Value> fusedOperands;
2090  SmallVector<Location> fusedLocs{genericOp.getLoc()};
2091  fusedIndexMaps.reserve(genericOp->getNumOperands());
2092  fusedOperands.reserve(genericOp.getNumDpsInputs());
2093  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2094  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2095  if (inputOperand == opOperand)
2096  continue;
2097  Value inputValue = inputOperand->get();
2098  fusedIndexMaps.push_back(
2099  genericOp.getMatchingIndexingMap(inputOperand));
2100  fusedOperands.push_back(inputValue);
2101  fusedLocs.push_back(inputValue.getLoc());
2102  }
2103  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2104  fusedIndexMaps.push_back(
2105  genericOp.getMatchingIndexingMap(&outputOperand));
2106 
2107  // Check if the operation shapes to loops map is computable.
2108  if (!inversePermutation(
2109  concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2110  return rewriter.notifyMatchFailure(
2111  genericOp, "fused op loop bound computation failed");
2112  }
2113 
2114  // Create a constant scalar value from the splat constant.
2115  Value scalarConstant =
2116  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
2117 
2118  SmallVector<Value> outputOperands = genericOp.getOutputs();
2119  auto fusedOp = rewriter.create<GenericOp>(
2120  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2121  /*inputs=*/fusedOperands,
2122  /*outputs=*/outputOperands,
2123  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2124  genericOp.getIteratorTypes(),
2125  /*doc=*/nullptr,
2126  /*library_call=*/nullptr);
2127 
2128  // Map the block argument corresponding to the replaced argument with the
2129  // scalar constant.
2130  Region &region = genericOp->getRegion(0);
2131  Block &entryBlock = *region.begin();
2132  IRMapping mapping;
2133  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2134  scalarConstant);
2135  Region &fusedRegion = fusedOp->getRegion(0);
2136  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2137  mapping);
2138  rewriter.replaceOp(genericOp, fusedOp->getResults());
2139  return success();
2140  }
2141  return failure();
2142  }
2143 };
2144 
2145 } // namespace
2146 
2147 //===---------------------------------------------------------------------===//
2148 // Miscellaneous patterns that help fusion.
2149 //===---------------------------------------------------------------------===//
2150 
2151 namespace {
2152 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2153 /// value of the `outs` operand is not used within the op. This is only
2154 /// implemented for `linalg.generic` operations for now, but should hold for all
2155 /// linalg structured ops.
2156 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2157  using OpRewritePattern<GenericOp>::OpRewritePattern;
2158 
2159  LogicalResult matchAndRewrite(GenericOp op,
2160  PatternRewriter &rewriter) const override {
2161  rewriter.startOpModification(op);
2162  bool modifiedOutput = false;
2163  Location loc = op.getLoc();
2164  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2165  if (!op.payloadUsesValueFromOperand(&opOperand)) {
2166  Value operandVal = opOperand.get();
2167  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2168  if (!operandType)
2169  continue;
2170 
2171  // If outs is sparse, leave it to the sparsifier.
2172  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2173  continue;
2174 
2175  // If outs is already an `empty` operation, nothing to do.
2176  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2177  if (definingOp)
2178  continue;
2179  modifiedOutput = true;
2180  SmallVector<OpFoldResult> mixedSizes =
2181  tensor::getMixedSizes(rewriter, loc, operandVal);
2182  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
2183  loc, mixedSizes, operandType.getElementType());
2184  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2185  }
2186  }
2187  if (!modifiedOutput) {
2188  rewriter.cancelOpModification(op);
2189  return failure();
2190  }
2191  rewriter.finalizeOpModification(op);
2192  return success();
2193  }
2194 };
2195 
2196 /// Fold linalg.fill into linalg.generic
2197 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2198  using OpRewritePattern<GenericOp>::OpRewritePattern;
2199 
2200  LogicalResult matchAndRewrite(GenericOp genericOp,
2201  PatternRewriter &rewriter) const override {
2202  if (!genericOp.hasPureTensorSemantics())
2203  return failure();
2204  bool fillFound = false;
2205  Block &payload = genericOp.getRegion().front();
2206  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2207  if (!genericOp.payloadUsesValueFromOperand(opOperand))
2208  continue;
2209  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2210  if (!fillOp)
2211  continue;
2212  fillFound = true;
2213  Value fillVal = fillOp.value();
2214  auto resultType =
2215  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2216  Value convertedVal =
2217  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2218  /*isUnsignedCast =*/false);
2219  rewriter.replaceAllUsesWith(
2220  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2221  }
2222  return success(fillFound);
2223  }
2224 };
2225 } // namespace
2226 
2227 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2228  RewritePatternSet &patterns,
2229  const ControlFusionFn &controlFoldingReshapes) {
2230  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2231  controlFoldingReshapes);
2232  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2233  controlFoldingReshapes);
2234  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2235  controlFoldingReshapes);
2236 }
2237 
2238 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2239  RewritePatternSet &patterns,
2240  const ControlFusionFn &controlFoldingReshapes) {
2241  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2242  controlFoldingReshapes);
2243  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2244  patterns.getContext(), controlFoldingReshapes);
2245  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2246  controlFoldingReshapes);
2247 }
2248 
2249 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2250  RewritePatternSet &patterns,
2251  const ControlFusionFn &controlElementwiseOpsFusion) {
2252  auto *context = patterns.getContext();
2253  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2254  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2255  RemoveOutsDependency>(context);
2256  // Add the patterns that clean up dead operands and results.
2257  populateEraseUnusedOperandsAndResultsPatterns(patterns);
2258 }
2259 
2260 void mlir::linalg::populateCollapseDimensions(
2261  RewritePatternSet &patterns,
2262  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2263  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2264  CollapseLinalgDimensions<linalg::CopyOp>>(
2265  patterns.getContext(), controlCollapseDimensions);
2266 }
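A hedged client-side sketch showing how these entry points are typically wired together (assumes an `MLIRContext *context` and a root `Operation *op` are available; mirrors the default control function of the test pass below):

```cpp
// Hedged sketch: populate and greedily apply the elementwise fusion
// patterns, fusing only producers with a single use.
RewritePatternSet patterns(context);
linalg::populateElementwiseOpsFusionPatterns(
    patterns, [](OpOperand *fusedOperand) {
      Operation *producer = fusedOperand->get().getDefiningOp();
      return producer && producer->hasOneUse();
    });
(void)applyPatternsGreedily(op, std::move(patterns));
```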
2267 
2268 //===---------------------------------------------------------------------===//
2269 // Passes
2270 //===---------------------------------------------------------------------===//
2271 
2272 namespace {
2273 
2274 /// Pass that fuses generic ops on tensors. Used only for testing.
2275 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2276 // patterns added here heavily depends on the cost function used. Having an
2277 // opinionated pass of this form is not recommended. Deprecate this pass in
2278 // favor of test passes that check the functionality of each of the patterns
2279 // added here individually.
2280 struct LinalgElementwiseOpFusionPass
2281  : public impl::LinalgElementwiseOpFusionPassBase<
2282  LinalgElementwiseOpFusionPass> {
2283  using impl::LinalgElementwiseOpFusionPassBase<
2284  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2285  void runOnOperation() override {
2286  Operation *op = getOperation();
2287  MLIRContext *context = op->getContext();
2288  RewritePatternSet patterns(context);
2289 
2290  // Add folding with reshape by expansion patterns.
2291  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2292  Operation *producer = fusedOperand->get().getDefiningOp();
2293  return producer && producer->hasOneUse();
2294  };
2295 
2296  // Add elementwise op fusion patterns.
2297  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2298  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2299  tensor::populateBubbleUpExpandShapePatterns(patterns);
2300 
2301  // General canonicalization patterns.
2302  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2303  GenericOp::getCanonicalizationPatterns(patterns, context);
2304  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2305  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2306  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2307  patterns);
2308 
2309  // Add constant folding patterns.
2310  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2311 
2312  // Use TopDownTraversal for compile time reasons.
2313  (void)applyPatternsGreedily(op, std::move(patterns),
2314  GreedyRewriteConfig().setUseTopDownTraversal());
2315  }
2316 };
2317 
2318 } // namespace