//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the producer's result in the
/// consumer.
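/// For example, if the producer's result map is `(d0, d1) -> (d1, d0)` (a
/// transpose), a producer argument map is the identity, and the consumer
/// accesses the fused operand with the identity map, then the inverse of the
/// result map is `(d0, d1) -> (d1, d0)`, and composing the argument map with
/// it and then with the consumer map yields `(d0, d1) -> (d1, d0)`: the fused
/// op reads the producer's argument transposed.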
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

// Checks if the given operand can be dropped, and whether the remaining
// operands of the fused producer & consumer can still compute the bounds
// of the op after fusion.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns the set of indices of the producer's results that would be
/// preserved after the fusion.
/// * Note: the actual fusion implementation may not agree with the result of
///   this method; this function gives a prediction based on an optimized
///   fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
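/// For illustration, a minimal fusable pair (same iteration space, producer
/// all-parallel, fused operand is a consumer input, producer result map a
/// permutation) looks like:
///
///   #map = affine_map<(d0) -> (d0)>
///   %0 = linalg.generic {indexing_maps = [#map, #map],
///                        iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {...}
///   %1 = linalg.generic {indexing_maps = [#map, #map],
///                        iterator_types = ["parallel"]}
///       ins(%0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {...}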
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that
  // defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
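/// The fused block's arguments are added in the order established below:
/// consumer inputs before the fused operand (step 3), producer inputs
/// (step 4), the remaining consumer inputs (step 5), preserved producer
/// outputs (step 6), and finally consumer outputs (step 7).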
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()),
          fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index
  // operations to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer,
  // after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      // Find the producer of the operand.
      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace the producer's and consumer's results with those of the
      // fused op.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace
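
// For reference, a minimal driver sketch for this pattern set; `funcOp` and
// the always-true control function here are assumptions, not part of this
// file:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));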

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is
///    split into two.
///  - The loop used to access the second dimension of the fused tensor is
///    kept as is.
///  - The loop used to access the third dimension of the fused tensor is
///    split into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of
///  the modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the linalg generic is now 6-D, reshapes can
///  be introduced on the operands to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of loops in the original operation to the extent
  /// of loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
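
// Worked example (an illustration, not exercised by this file directly):
// with a fused operand map `(d0, d1, d2) -> (d1, d0, d2)` and reshape
// reassociation [[0, 1], [2], [3, 4, 5]], result 0 accesses d1 and expands
// to 2 dims, result 1 accesses d0 and stays 1 dim, and result 2 accesses d2
// and expands to 3 dims. Numbering the expanded dims in original loop order
// gives the domain reassociation d0 -> [0], d1 -> [1, 2], d2 -> [3, 4, 5]
// and a 6-D expanded op.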

/// Return the indexing map to use in the expanded op for the given
/// `indexingMap` of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
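
// Continuing the example above, the original indexing map
// `(d0, d1, d2) -> (d1, d0, d2)` becomes
// `(e0, ..., e5) -> (e1, e2, e0, e3, e4, e5)`: each result dim is replaced
// by the list of expanded dims it folds into.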

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}
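
// E.g. an operand of type tensor<?x?xf32> with indexing map
// `(d0, d1, d2) -> (d1, d2)`, where d1 expands to shapes [s1, s2] and d2 to
// [s3, s4, s5], gets the expanded shape [s1, s2, s3, s4, s5].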

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
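
// For the tensor<?x?xf32> operand above, the two result dims expand to 2 and
// 3 dims respectively, so the reassociation is [[0, 1], [2, 3, 4]].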

/// Update the body of an expanded linalg operation having index semantics.
/// The indices of the original operation need to be recovered by linearizing
/// the indices of the corresponding dimensions of the expanded operation. For
/// now it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded
  // indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    OpFoldResult newIndex =
        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}
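
// E.g. if original dim d folds into expanded dims (e0, e1, e2) with shapes
// (s0, s1, s2) and indices (i0, i1, i2), the recovered index is
// (i0 * s1 + i1) * s2 + i2, built one `idx + acc * shape` step at a time.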

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we inverse-permute it and
// then flatten it. Then we inverse-permute it again to get the final expanded
// transpose permutation. For example,
//
// permutation = [2, 0, 1]
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
// inverse permutation = [1, 2, 0]
// applied to reassociation_map and then flattened becomes
// flattened permutation = [2, 3, 4, 5, 0, 1]
// final permutation is the inverse of the flattened permutation.
//
// Becomes
//
// permutation = [4, 5, 0, 1, 2, 3]
static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
                                            TransposeOp transposeOp,
                                            Value expandedInput, Value output,
                                            ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
                                      output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // Start with all-parallel iterator types, then copy each original iterator
  // type to all the dims it expands into.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = rewriter.create<GenericOp>(
      linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
      expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape, collapsedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in output shape before the
    // `linalgOp` to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured
/// op, when the reshape op is collapsing dimensions. The dimensionality of
/// the loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

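/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape op
/// by bubbling the collapse past the pad; only unfolded dimensions may carry
/// non-zero padding. For example (a sketch):
///
///   %c = tensor.collapse_shape %s [[0, 1], [2]]
///       : tensor<?x4x?xf32> into tensor<?x?xf32>
///   %p = tensor.pad %c low[0, 1] high[0, 2] { ... }
///
/// becomes a pad of %s (padding only the last dimension) followed by a
/// tensor.collapse_shape of the padded value.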
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Folded dimensions cannot carry any padding.
    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have
    // the same type as the returns of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. The projected
/// permutation semantics ensure that all the elements of the returned
/// reassociation are distinct.
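/// For example, with indexingMap `(d0, d1, d2, d3) -> (d2, d0, d3)` and
/// rangeReassociation `[1, 2]`, the folded range dims access d0 and d3, so
/// the returned domain reassociation is `[0, 3]`.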
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
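/// For example, `(d0, d1, d2) -> (d2, d0, d1)` preserves the sequence [0, 1]
/// (d0 and d1 appear contiguously and in order), while
/// `(d0, d1, d2) -> (d0, d2, d1)` does not.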
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel", "parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel", "parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are either all parallel or all
    // reduction.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration
        // dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

1393 /// Helper class to carry state while collapsing the `linalg.generic` op.
1394 namespace {
1395 class CollapsingInfo {
1396 public:
1397  LogicalResult initialize(unsigned origNumLoops,
1398  ArrayRef<ReassociationIndices> foldedIterationDims) {
1399  llvm::SmallDenseSet<int64_t, 4> processedDims;
1400  // Find all the dims that are folded.
1401  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1402  if (foldedIterationDim.empty())
1403  continue;
 1404  // If the folded dims contain dims that are already folded, that's an
 1405  // illegal specification. Repetition within a list is also illegal.
1406  for (auto dim : foldedIterationDim) {
1407  if (dim >= origNumLoops)
1408  return failure();
1409  if (processedDims.count(dim))
1410  return failure();
1411  processedDims.insert(dim);
1412  }
1413  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1414  foldedIterationDim.end());
1415  }
1416  if (processedDims.size() > origNumLoops)
1417  return failure();
1418 
1419  // Add all the preserved dims of the original op as single
1420  // elements to `collapsedOpToOrigOpIterationDim`.
1421  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1422  if (processedDims.count(dim))
1423  continue;
1424  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1425  }
1426 
 1427  llvm::sort(collapsedOpToOrigOpIterationDim,
 1428             [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
 1429               return lhs[0] < rhs[0];
 1430             });
1431  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1432  for (const auto &foldedDims :
1433  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1434  for (const auto &dim : enumerate(foldedDims.value()))
1435  origOpToCollapsedOpIterationDim[dim.value()] =
1436  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1437  }
1438  return success();
1439  }
1440 
1441  /// Return mapping from collapsed loop domain to original loop domain.
1442  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1443  return collapsedOpToOrigOpIterationDim;
1444  }
1445 
 1446  /// Return mapping from original loop domain to collapsed loop domain. The
 1447  /// mapping is a pair: the first value is the dimension in the collapsed loop
 1448  /// that the original loop is mapped to; the second is the relative position
 1449  /// of the original loop within the folded list. For example, if the original
 1450  /// loop domain is 5D and the collapsed loop domain folds all of it, i.e.
 1451  ///
 1452  /// ```
 1453  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1454  /// ```
1455  ///
1456  /// then
1457  ///
1458  /// ```
1459  /// origOpToCollapsedOpMapping[0] = {0, 0};
1460  /// origOpToCollapsedOpMapping[1] = {0, 1};
1461  /// origOpToCollapsedOpMapping[2] = {0, 2};
1462  /// origOpToCollapsedOpMapping[3] = {1, 0};
1463  /// origOpToCollapsedOpMapping[4] = {1, 1};
1464  /// ```
1465  ///
1466  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1467  return origOpToCollapsedOpIterationDim;
1468  }
1469 
1470  /// Return the collapsed op iteration domain rank.
1471  unsigned getCollapsedOpIterationRank() const {
1472  return collapsedOpToOrigOpIterationDim.size();
1473  }
1474 
1475 private:
1476  /// Map from the iteration domain index in collapsed op to the iteration
1477  /// domain indices in the original op.
1478  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1479 
1480  /// Map from iteration domain index in the original op to the iteration domain
1481  /// index in the collapsed op.
1482  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1483 };
1484 } // namespace
1485 
1486 /// Get the iterator types for the collapsed operation given the original
1487 /// iterator types and collapsed dimensions.
 1488 static SmallVector<utils::IteratorType>
 1489 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
 1490                             const CollapsingInfo &collapsingInfo) {
1491  SmallVector<utils::IteratorType> collapsedIteratorTypes;
1492  for (ReassociationIndicesRef foldedIterDims :
1493  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1494  assert(!foldedIterDims.empty() &&
1495  "reassociation indices expected to have non-empty sets");
 1496  // Just pick the iterator type of the first folded dim. Pre-condition
 1497  // checks are expected to have verified that the iterator types of all
 1498  // folded dimensions are the same.
1499  collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1500  }
1501  return collapsedIteratorTypes;
1502 }
1503 
1504 /// Compute the indexing map in the collapsed op that corresponds to the given
1505 /// `indexingMap` of the original operation.
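/// For illustration (hypothetical values): with original indexing map
/// (d0, d1, d2, d3) -> (d0, d1, d3) and folded iteration dims [[0, 1]], the
/// collapsed op has three loops and the collapsed map is
/// (d0, d1, d2) -> (d0, d2).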
1506 static AffineMap
 1507 getCollapsedOpIndexingMap(AffineMap indexingMap,
 1508                           const CollapsingInfo &collapsingInfo) {
1509  MLIRContext *context = indexingMap.getContext();
1510  assert(indexingMap.isProjectedPermutation() &&
1511  "expected indexing map to be projected permutation");
1512  SmallVector<AffineExpr> resultExprs;
1513  auto origOpToCollapsedOpMapping =
1514  collapsingInfo.getOrigOpToCollapsedOpMapping();
1515  for (auto expr : indexingMap.getResults()) {
1516  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
 1517  // If the dim is not the first of its collapsed group, do nothing.
1518  if (origOpToCollapsedOpMapping[dim].second != 0)
1519  continue;
1520  // The next n-dims are guaranteed to be collapsed. So just use the
1521  // iteration dimension of the collapsed op.
1522  resultExprs.push_back(
1523  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1524  }
1525  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1526  resultExprs, context);
1527 }
1528 
1529 /// Return the `reassociation` indices to use to collapse the operand when the
1530 /// iteration space of a generic op is collapsed.
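/// For illustration (hypothetical values): with 4 original loops collapsed as
/// [[0, 1], [2, 3]] and an operand indexing map
/// (d0, d1, d2, d3) -> (d2, d3, d0, d1), the operand reassociation is
/// [[0, 1], [2, 3]]: results 0-1 collapse into one operand dimension and
/// results 2-3 into another.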
 1531 static SmallVector<ReassociationIndices>
 1532 getOperandReassociation(AffineMap indexingMap,
 1533                         const CollapsingInfo &collapsingInfo) {
1534  unsigned counter = 0;
1535  SmallVector<ReassociationIndices> operandReassociation;
1536  auto origOpToCollapsedOpMapping =
1537  collapsingInfo.getOrigOpToCollapsedOpMapping();
1538  auto collapsedOpToOrigOpMapping =
1539  collapsingInfo.getCollapsedOpToOrigOpMapping();
1540  while (counter < indexingMap.getNumResults()) {
1541  unsigned dim =
1542  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
 1543  // This is the start of a group of collapsed dimensions of the iteration
 1544  // space that is guaranteed to be preserved in the indexing map. The number
 1545  // of folded dims is obtained from the collapsed op to original op mapping.
1546  unsigned numFoldedDims =
1547  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1548  .size();
1549  if (origOpToCollapsedOpMapping[dim].second == 0) {
1550  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1551  operandReassociation.emplace_back(range.begin(), range.end());
1552  }
1553  counter += numFoldedDims;
1554  }
1555  return operandReassociation;
1556 }
1557 
1558 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1559 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1560  OpOperand *opOperand,
1561  const CollapsingInfo &collapsingInfo,
1562  OpBuilder &builder) {
1563  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1564  SmallVector<ReassociationIndices> operandReassociation =
1565  getOperandReassociation(indexingMap, collapsingInfo);
1566 
 1567  // If the number of entries in the reassociation for the operand is the
 1568  // same as the number of results of the indexing map, then there is nothing
 1569  // to do for this operand.
1570  Value operand = opOperand->get();
1571  if (operandReassociation.size() == indexingMap.getNumResults())
1572  return operand;
1573 
1574  // Insert a reshape to collapse the dimensions.
1575  if (isa<MemRefType>(operand.getType())) {
1576  return builder
1577  .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1578  .getResult();
1579  }
1580  return builder
1581  .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1582  .getResult();
1583 }
1584 
 1585 /// Modify the `linalg.index` operations in the original generic op to their
 1586 /// values in the collapsed operation.
 1587 static void generateCollapsedIndexingRegion(
 1588     Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1589  ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1590  OpBuilder::InsertionGuard g(rewriter);
1591  rewriter.setInsertionPointToStart(block);
1592 
1593  // Collect all the original index ops.
1594  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1595 
 1596  // For each folded dimension list, resolve the original induction variable
 1597  // values in terms of the folded dimension induction variable:
 1598  //   i_{folded} = (i0 * d1 + i1) * d2 + i2
 1599  // which can be inverted to
 1600  //   i2 = i_{folded} % d2
 1601  //   i1 = (i_{folded} / d2) % d1
 1602  //   i0 = i_{folded} / (d1 * d2)
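  // As a numeric check (values chosen purely for illustration): with d1 = 4,
  // d2 = 5 and i_{folded} = 37, this recovers i2 = 37 % 5 = 2,
  // i1 = (37 / 5) % 4 = 3 and i0 = 37 / 20 = 1; indeed
  // (1 * 4 + 3) * 5 + 2 = 37.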
1603  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1604  for (auto foldedDims :
1605  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1606  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1607  Value newIndexVal =
1608  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1609  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1610  Value loopDim =
1611  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1612  indexReplacementVals[dim] =
1613  rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1614  newIndexVal =
1615  rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1616  }
1617  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1618  }
1619 
1620  for (auto indexOp : indexOps) {
1621  auto dim = indexOp.getDim();
1622  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1623  }
1624 }
1625 
 1626 void collapseOperandsAndResults(LinalgOp op,
 1627                                 const CollapsingInfo &collapsingInfo,
1628  RewriterBase &rewriter,
1629  SmallVectorImpl<Value> &inputOperands,
1630  SmallVectorImpl<Value> &outputOperands,
1631  SmallVectorImpl<Type> &resultTypes) {
1632  Location loc = op->getLoc();
1633  inputOperands =
1634  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1635  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1636  rewriter);
1637  });
1638 
1639  // Get the output operands and result types.
1640  resultTypes.reserve(op.getNumDpsInits());
1641  outputOperands.reserve(op.getNumDpsInits());
1642  for (OpOperand &output : op.getDpsInitsMutable()) {
1643  Value newOutput =
1644  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1645  outputOperands.push_back(newOutput);
1646  // If the op has "buffer semantics", then the init operands are ranked
1647  // memrefs and the op has no results.
1648  if (!op.hasPureBufferSemantics())
1649  resultTypes.push_back(newOutput.getType());
1650  }
1651 }
1652 
 1653 /// Clone a `LinalgOp` to a collapsed version of the same name.
1654 template <typename OpTy>
1655 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1656  const CollapsingInfo &collapsingInfo) {
1657  return nullptr;
1658 }
1659 
1660 /// Collapse any `LinalgOp` that does not require any specialization such as
1661 /// indexing_maps, iterator_types, etc.
1662 template <>
1663 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1664  const CollapsingInfo &collapsingInfo) {
1665  SmallVector<Value> inputOperands, outputOperands;
1666  SmallVector<Type> resultTypes;
1667  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1668  outputOperands, resultTypes);
1669 
1670  return clone(
1671  rewriter, origOp, resultTypes,
1672  llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1673 }
1674 
1675 /// Collapse a `GenericOp`
1676 template <>
 1677 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
 1678                                         GenericOp origOp,
1679  const CollapsingInfo &collapsingInfo) {
1680  SmallVector<Value> inputOperands, outputOperands;
1681  SmallVector<Type> resultTypes;
1682  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1683  outputOperands, resultTypes);
1684  SmallVector<AffineMap> indexingMaps(
1685  llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1686  return getCollapsedOpIndexingMap(map, collapsingInfo);
1687  }));
1688 
 1689   SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
 1690       origOp.getIteratorTypesArray(), collapsingInfo));
1691 
1692  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1693  origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1694  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1695  Block *origOpBlock = &origOp->getRegion(0).front();
1696  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1697  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1698  collapsedOpBlock->getArguments());
1699  return collapsedOp;
1700 }
1701 
1702 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1703  RewriterBase &rewriter) {
1704  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1705  return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1706  } else {
1707  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1708  }
1709 }
1710 
1711 /// Implementation of fusion with reshape operation by collapsing dimensions.
1712 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1713  LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1714  RewriterBase &rewriter) {
1715  // Bail on trivial no-op cases.
1716  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1717  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1718  return foldedDims.size() <= 1;
1719  }))
1720  return failure();
1721 
1722  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1723  if (hasPureBufferSemantics &&
1724  !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1725  MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1726  if (!memRefToCollapse)
1727  return true;
1728 
1729  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1730  memRefToCollapse, foldedIterationDims);
1731  }))
1732  return rewriter.notifyMatchFailure(op,
1733  "memref is not guaranteed collapsible");
1734 
1735  CollapsingInfo collapsingInfo;
1736  if (failed(
1737  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1738  return rewriter.notifyMatchFailure(
1739  op, "illegal to collapse specified dimensions");
1740  }
1741 
1742  // Bail on non-canonical ranges.
1743  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1744  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1745  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1746  return cast<IntegerAttr>(attr).getInt() == value;
1747  llvm::APInt actual;
1748  return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1749  actual.getSExtValue() == value;
1750  };
1751  if (!llvm::all_of(loopRanges, [&](Range range) {
1752  return opFoldIsConstantValue(range.offset, 0) &&
1753  opFoldIsConstantValue(range.stride, 1);
1754  })) {
1755  return rewriter.notifyMatchFailure(
1756  op, "expected all loop ranges to have zero start and unit stride");
1757  }
1758 
1759  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1760 
1761  Location loc = op->getLoc();
1762  SmallVector<OpFoldResult> loopBound =
1763  llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1764 
1765  if (collapsedOp.hasIndexSemantics()) {
1766  // Collect the loop range of the generic op.
1767  OpBuilder::InsertionGuard g(rewriter);
1768  rewriter.setInsertionPoint(collapsedOp);
1769  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1770  collapsingInfo, loopBound, rewriter);
1771  }
1772 
1773  // Insert expanding reshape for the result to get back the original result
1774  // type.
1775  SmallVector<Value> results;
1776  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1777  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1778  auto originalResultType =
1779  cast<ShapedType>(originalResult.value().getType());
1780  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1781  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1782  AffineMap indexingMap =
1783  op.getIndexingMapMatchingResult(originalResult.value());
1784  SmallVector<ReassociationIndices> reassociation =
1785  getOperandReassociation(indexingMap, collapsingInfo);
1786  assert(
1787  indexingMap.isProjectedPermutation() &&
1788  "Expected indexing map to be a projected permutation for collapsing");
1789  SmallVector<OpFoldResult> resultShape =
1790  applyPermutationMap(indexingMap, ArrayRef(loopBound));
1791  Value result;
1792  if (isa<MemRefType>(collapsedOpResult.getType())) {
1793  MemRefType expandShapeResultType = MemRefType::get(
1794  originalResultType.getShape(), originalResultType.getElementType());
1795  result = rewriter.create<memref::ExpandShapeOp>(
1796  loc, expandShapeResultType, collapsedOpResult, reassociation,
1797  resultShape);
1798  } else {
1799  result = rewriter.create<tensor::ExpandShapeOp>(
1800  loc, originalResultType, collapsedOpResult, reassociation,
1801  resultShape);
1802  }
1803  results.push_back(result);
1804  } else {
1805  results.push_back(collapsedOpResult);
1806  }
1807  }
1808  return CollapseResult{results, collapsedOp};
1809 }
1810 
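// A minimal usage sketch (assuming `linalgOp` and `rewriter` are in scope):
// collapse loops 0 and 1 of a linalg op and replace its results on success.
//   SmallVector<ReassociationIndices> folded = {{0, 1}};
//   FailureOr<CollapseResult> res =
//       collapseOpIterationDims(linalgOp, folded, rewriter);
//   if (succeeded(res))
//     rewriter.replaceOp(linalgOp, res->results);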
1811 namespace {
1812 
1813 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1814 /// contracting dimensions of the loop.
1815 class FoldWithProducerReshapeOpByCollapsing
1816  : public OpRewritePattern<GenericOp> {
1817 public:
 1818  // TODO: support fusion with all linalg ops, not just generic.
1819  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1820  ControlFusionFn foldReshapes,
1821  PatternBenefit benefit = 1)
1822  : OpRewritePattern<GenericOp>(context, benefit),
1823  controlFoldingReshapes(std::move(foldReshapes)) {}
1824 
1825  LogicalResult matchAndRewrite(GenericOp genericOp,
1826  PatternRewriter &rewriter) const override {
1827  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1828  tensor::ExpandShapeOp reshapeOp =
1829  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1830  if (!reshapeOp)
1831  continue;
1832 
1833  SmallVector<ReassociationIndices> collapsableIterationDims =
1834  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1835  reshapeOp.getReassociationIndices());
1836  if (collapsableIterationDims.empty() ||
1837  !controlFoldingReshapes(&opOperand)) {
1838  continue;
1839  }
1840 
1841  std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1842  genericOp, collapsableIterationDims, rewriter);
1843  if (!collapseResult) {
1844  return rewriter.notifyMatchFailure(
1845  genericOp, "failed to do the fusion by collapsing transformation");
1846  }
1847 
1848  rewriter.replaceOp(genericOp, collapseResult->results);
1849  return success();
1850  }
1851  return failure();
1852  }
1853 
1854 private:
1855  ControlFusionFn controlFoldingReshapes;
1856 };
1857 
1858 /// Pattern to fold a tensor.collapse_shape op with its producer generic op
 1859 /// by collapsing the dimensions of the loops in the producer op.
1860 struct FoldReshapeWithGenericOpByCollapsing
1861  : public OpRewritePattern<tensor::CollapseShapeOp> {
1862 
1863  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1864  ControlFusionFn foldReshapes,
1865  PatternBenefit benefit = 1)
1866  : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1867  controlFoldingReshapes(std::move(foldReshapes)) {}
1868 
1869  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1870  PatternRewriter &rewriter) const override {
1871  // Fold only if all constraints of fusing with reshape by collapsing are
1872  // met.
1873  auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1874  if (!producerResult) {
1875  return rewriter.notifyMatchFailure(reshapeOp,
1876  "source not produced by an operation");
1877  }
1878 
 1879  // TODO: support fusion with all linalg producers, not just generic.
1880  auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1881  if (!producer) {
1882  return rewriter.notifyMatchFailure(reshapeOp,
1883  "producer not a generic op");
1884  }
1885 
1886  SmallVector<ReassociationIndices> collapsableIterationDims =
 1887         getCollapsableIterationSpaceDims(
 1888             producer,
1889  producer.getDpsInitOperand(producerResult.getResultNumber()),
1890  reshapeOp.getReassociationIndices());
1891  if (collapsableIterationDims.empty()) {
1892  return rewriter.notifyMatchFailure(
1893  reshapeOp, "failed preconditions of fusion with producer generic op");
1894  }
1895 
1896  if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1897  return rewriter.notifyMatchFailure(reshapeOp,
1898  "fusion blocked by control function");
1899  }
1900 
1901  std::optional<CollapseResult> collapseResult =
1902  collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
1903  if (!collapseResult) {
1904  return rewriter.notifyMatchFailure(
1905  producer, "failed to do the fusion by collapsing transformation");
1906  }
1907 
1913  // Find the replacement for the reshape op. Since the replacements have the
 1914  // same type as the results of the original generic op, the consumer reshape
1915  // op can be replaced by the source of the expand_shape op that defines
1916  // the replacement.
1917  Value reshapeReplacement =
1918  (collapseResult
1919  ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
1920  if (auto expandOp =
1921  reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
1922  reshapeReplacement = expandOp.getSrc();
1923  }
1924  rewriter.replaceOp(reshapeOp, reshapeReplacement);
1925  rewriter.replaceOp(producer, collapseResult->results);
1926  return success();
1927  }
1928 
1929 private:
1930  ControlFusionFn controlFoldingReshapes;
1931 };
1932 
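/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
/// by collapsing the pad to act on the source of the reshape. This applies
/// only when, for every reassociation group containing more than one
/// dimension, the low and high padding of all dimensions in that group is
/// zero.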
1933 class FoldPadWithProducerReshapeOpByCollapsing
1934  : public OpRewritePattern<tensor::PadOp> {
1935 public:
1936  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1937  ControlFusionFn foldReshapes,
1938  PatternBenefit benefit = 1)
1939  : OpRewritePattern<tensor::PadOp>(context, benefit),
1940  controlFoldingReshapes(std::move(foldReshapes)) {}
1941 
1942  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1943  PatternRewriter &rewriter) const override {
1944  tensor::ExpandShapeOp reshapeOp =
1945  padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1946  if (!reshapeOp)
1947  return failure();
1948  if (!reshapeOp->hasOneUse())
1949  return failure();
1950 
1951  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1952  return rewriter.notifyMatchFailure(padOp,
1953  "fusion blocked by control function");
1954  }
1955 
1956  ArrayRef<int64_t> low = padOp.getStaticLow();
1957  ArrayRef<int64_t> high = padOp.getStaticHigh();
1958  SmallVector<ReassociationIndices> reassociations =
1959  reshapeOp.getReassociationIndices();
1960 
1961  for (auto reInd : reassociations) {
1962  if (reInd.size() == 1)
1963  continue;
1964  if (llvm::any_of(reInd, [&](int64_t ind) {
1965  return low[ind] != 0 || high[ind] != 0;
1966  })) {
1967  return failure();
1968  }
1969  }
1970 
1971  SmallVector<OpFoldResult> newLow, newHigh;
1972  RankedTensorType collapsedType = reshapeOp.getSrcType();
1973  RankedTensorType paddedType = padOp.getResultType();
1974  SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1975  SmallVector<OpFoldResult> expandedPaddedSizes(
1976  getMixedValues(reshapeOp.getStaticOutputShape(),
1977  reshapeOp.getOutputShape(), rewriter));
1978  AffineExpr d0, d1, d2;
1979  bindDims(rewriter.getContext(), d0, d1, d2);
1980  auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1981  Location loc = reshapeOp->getLoc();
1982  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1983  OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1984  OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1985  if (reInd.size() == 1) {
1986  collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
 1987        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
 1988            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1989  expandedPaddedSizes[reInd[0]] = paddedSize;
1990  }
1991  newLow.push_back(l);
1992  newHigh.push_back(h);
1993  }
1994 
1995  RankedTensorType collapsedPaddedType =
1996  paddedType.clone(collapsedPaddedShape);
1997  auto newPadOp = rewriter.create<tensor::PadOp>(
1998  loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1999  padOp.getConstantPaddingValue(), padOp.getNofold());
2000 
2001  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
2002  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2003  expandedPaddedSizes);
2004 
2005  return success();
2006  }
2007 
2008 private:
2009  ControlFusionFn controlFoldingReshapes;
2010 };
2011 
2012 /// Pattern to collapse dimensions.
2013 template <typename LinalgType>
2014 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2015 public:
2016  CollapseLinalgDimensions(MLIRContext *context,
2017  GetCollapsableDimensionsFn collapseDimensions,
2018  PatternBenefit benefit = 1)
2019  : OpRewritePattern<LinalgType>(context, benefit),
2020  controlCollapseDimension(std::move(collapseDimensions)) {}
2021 
2022  LogicalResult matchAndRewrite(LinalgType op,
2023  PatternRewriter &rewriter) const override {
2024  SmallVector<ReassociationIndices> collapsableIterationDims =
2025  controlCollapseDimension(op);
2026  if (collapsableIterationDims.empty())
2027  return failure();
2028 
2029  // Check if the specified list of dimensions to collapse is a valid list.
2030  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2031  collapsableIterationDims)) {
2032  return rewriter.notifyMatchFailure(
2033  op, "specified dimensions cannot be collapsed");
2034  }
2035 
2036  std::optional<CollapseResult> collapseResult =
2037  collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2038  if (!collapseResult) {
2039  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2040  }
2041  rewriter.replaceOp(op, collapseResult->results);
2042  return success();
2043  }
2044 
2045 private:
2046  GetCollapsableDimensionsFn controlCollapseDimension;
2047 };
2048 
2049 } // namespace
2050 
2051 //===---------------------------------------------------------------------===//
2052 // Methods and patterns that fuse constants with linalg.generic operations.
2053 //===---------------------------------------------------------------------===//
2054 
2055 namespace {
2056 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2057 /// handle cases where the constant is not single-valued.
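/// As an illustrative sketch (IR invented for exposition): a splat input such
/// as
///   %cst = arith.constant dense<2.0> : tensor<4xf32>
/// is dropped from the operand list of the generic op, and the payload uses
/// the scalar
///   %c = arith.constant 2.0 : f32
/// in place of the corresponding block argument.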
2058 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2059 public:
2060  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2061  : OpRewritePattern<GenericOp>(context, benefit) {}
2062 
2063  LogicalResult matchAndRewrite(GenericOp genericOp,
2064  PatternRewriter &rewriter) const override {
2065  if (!genericOp.hasPureTensorSemantics())
2066  return failure();
2067  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2068  Operation *def = opOperand->get().getDefiningOp();
2069  TypedAttr constantAttr;
2070  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2071  {
2072  DenseElementsAttr splatAttr;
2073  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2074  splatAttr.isSplat() &&
2075  splatAttr.getType().getElementType().isIntOrFloat()) {
2076  constantAttr = splatAttr.getSplatValue<TypedAttr>();
2077  return true;
2078  }
2079  }
2080  {
2081  IntegerAttr intAttr;
2082  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2083  constantAttr = intAttr;
2084  return true;
2085  }
2086  }
2087  {
2088  FloatAttr floatAttr;
2089  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2090  constantAttr = floatAttr;
2091  return true;
2092  }
2093  }
2094  return false;
2095  };
2096 
2097  auto resultValue = dyn_cast<OpResult>(opOperand->get());
2098  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2099  continue;
2100 
 2101  // The operands and the indexing_maps of the fused operation are the same
 2102  // as the operands and indexing_maps of the generic operation, with the
 2103  // values at the constant index dropped.
2104  SmallVector<AffineMap> fusedIndexMaps;
2105  SmallVector<Value> fusedOperands;
2106  SmallVector<Location> fusedLocs{genericOp.getLoc()};
2107  fusedIndexMaps.reserve(genericOp->getNumOperands());
2108  fusedOperands.reserve(genericOp.getNumDpsInputs());
2109  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2110  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2111  if (inputOperand == opOperand)
2112  continue;
2113  Value inputValue = inputOperand->get();
2114  fusedIndexMaps.push_back(
2115  genericOp.getMatchingIndexingMap(inputOperand));
2116  fusedOperands.push_back(inputValue);
2117  fusedLocs.push_back(inputValue.getLoc());
2118  }
2119  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2120  fusedIndexMaps.push_back(
2121  genericOp.getMatchingIndexingMap(&outputOperand));
2122 
 2123  // Check if the operation's shapes-to-loops map is computable.
2124  if (!inversePermutation(
2125  concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2126  return rewriter.notifyMatchFailure(
2127  genericOp, "fused op loop bound computation failed");
2128  }
2129 
2130  // Create a constant scalar value from the splat constant.
2131  Value scalarConstant =
2132  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
2133 
2134  SmallVector<Value> outputOperands = genericOp.getOutputs();
2135  auto fusedOp = rewriter.create<GenericOp>(
2136  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2137  /*inputs=*/fusedOperands,
2138  /*outputs=*/outputOperands,
2139  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2140  genericOp.getIteratorTypes(),
2141  /*doc=*/nullptr,
2142  /*library_call=*/nullptr);
2143 
2144  // Map the block argument corresponding to the replaced argument with the
2145  // scalar constant.
2146  Region &region = genericOp->getRegion(0);
2147  Block &entryBlock = *region.begin();
2148  IRMapping mapping;
2149  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2150  scalarConstant);
2151  Region &fusedRegion = fusedOp->getRegion(0);
2152  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2153  mapping);
2154  rewriter.replaceOp(genericOp, fusedOp->getResults());
2155  return success();
2156  }
2157  return failure();
2158  }
2159 };
2160 
2161 } // namespace
2162 
2163 //===---------------------------------------------------------------------===//
2164 // Miscellaneous patterns that help fusion.
2165 //===---------------------------------------------------------------------===//
2166 
2167 namespace {
2168 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2169 /// value of the `outs` operand is not used within the op. This is only
2170 /// implemented for `linalg.generic` operations for now, but should hold for all
2171 /// linalg structured ops.
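/// For illustration (hypothetical IR): an op of the form
///   %0 = linalg.generic ... outs(%computed : tensor<?xf32>)
/// whose payload never reads the corresponding block argument is rewritten to
///   %empty = tensor.empty(...) : tensor<?xf32>
///   %0 = linalg.generic ... outs(%empty : tensor<?xf32>)
/// which removes the false dependency on the previous value of the operand.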
2172 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
 2173   using OpRewritePattern<GenericOp>::OpRewritePattern;
 2174 
2175  LogicalResult matchAndRewrite(GenericOp op,
2176  PatternRewriter &rewriter) const override {
2177  rewriter.startOpModification(op);
2178  bool modifiedOutput = false;
2179  Location loc = op.getLoc();
2180  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2181  if (!op.payloadUsesValueFromOperand(&opOperand)) {
2182  Value operandVal = opOperand.get();
2183  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2184  if (!operandType)
2185  continue;
2186 
2187  // If outs is sparse, leave it to the sparsifier.
 2188       if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
 2189         continue;
2190 
2191  // If outs is already an `empty` operation, nothing to do.
2192  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2193  if (definingOp)
2194  continue;
2195  modifiedOutput = true;
2196  SmallVector<OpFoldResult> mixedSizes =
2197  tensor::getMixedSizes(rewriter, loc, operandVal);
2198  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
2199  loc, mixedSizes, operandType.getElementType());
2200  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2201  }
2202  }
2203  if (!modifiedOutput) {
2204  rewriter.cancelOpModification(op);
2205  return failure();
2206  }
2207  rewriter.finalizeOpModification(op);
2208  return success();
2209  }
2210 };
2211 
2212 /// Fold linalg.fill into linalg.generic
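/// For illustration (hypothetical IR): when an input operand is produced by
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?xf32>)
/// the payload's uses of the matching block argument are replaced by %cst
/// (converted to the element type if needed), leaving the fill operand dead.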
2213 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
 2214   using OpRewritePattern<GenericOp>::OpRewritePattern;
 2215 
2216  LogicalResult matchAndRewrite(GenericOp genericOp,
2217  PatternRewriter &rewriter) const override {
2218  if (!genericOp.hasPureTensorSemantics())
2219  return failure();
2220  bool fillFound = false;
2221  Block &payload = genericOp.getRegion().front();
2222  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2223  if (!genericOp.payloadUsesValueFromOperand(opOperand))
2224  continue;
2225  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2226  if (!fillOp)
2227  continue;
2228  fillFound = true;
2229  Value fillVal = fillOp.value();
2230  auto resultType =
2231  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2232  Value convertedVal =
2233  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2234  /*isUnsignedCast =*/false);
2235  rewriter.replaceAllUsesWith(
2236  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2237  }
2238  return success(fillFound);
2239  }
2240 };
2241 } // namespace
2242 
 2243 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
 2244     RewritePatternSet &patterns,
 2245     const ControlFusionFn &controlFoldingReshapes) {
2246  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2247  controlFoldingReshapes);
2248  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2249  controlFoldingReshapes);
2250  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2251  controlFoldingReshapes);
2252 }
2253 
 2254 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
 2255     RewritePatternSet &patterns,
 2256     const ControlFusionFn &controlFoldingReshapes) {
2257  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2258  controlFoldingReshapes);
2259  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2260  patterns.getContext(), controlFoldingReshapes);
2261  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2262  controlFoldingReshapes);
2263 }
2264 
 2265 void mlir::linalg::populateElementwiseOpsFusionPatterns(
 2266     RewritePatternSet &patterns,
 2267     const ControlFusionFn &controlElementwiseOpsFusion) {
2268  auto *context = patterns.getContext();
2269  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2270  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2271  RemoveOutsDependency>(context);
2272  // Add the patterns that clean up dead operands and results.
 2273   populateEraseUnusedOperandsAndResultsPatterns(patterns);
 2274 }
2275 
 2276 void mlir::linalg::populateCollapseDimensions(
 2277     RewritePatternSet &patterns,
 2278     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2279  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2280  CollapseLinalgDimensions<linalg::CopyOp>>(
2281  patterns.getContext(), controlCollapseDimensions);
2282 }
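// A minimal population sketch (illustrative, assuming `ctx` and `op` are in
// scope): gather the elementwise fusion patterns with a control function that
// always allows fusion, then apply them greedily.
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *) { return true; });
//   (void)applyPatternsGreedily(op, std::move(patterns));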
2283 
2284 //===---------------------------------------------------------------------===//
2285 // Passes
2286 //===---------------------------------------------------------------------===//
2287 
2288 namespace {
2289 
2290 /// Pass that fuses generic ops on tensors. Used only for testing.
2291 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2292 // patterns added here heavily depends on the cost function used. Having an
2293 // opinionated pass of this form is not recommended. Deprecate this pass in
2294 // favor of test passes that check the functionality of each of the patterns
2295 // added here individually.
2296 struct LinalgElementwiseOpFusionPass
2297  : public impl::LinalgElementwiseOpFusionPassBase<
2298  LinalgElementwiseOpFusionPass> {
2299  using impl::LinalgElementwiseOpFusionPassBase<
2300  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2301  void runOnOperation() override {
2302  Operation *op = getOperation();
2303  MLIRContext *context = op->getContext();
2304  RewritePatternSet patterns(context);
2305 
2306  // Add folding with reshape by expansion patterns.
2307  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2308  Operation *producer = fusedOperand->get().getDefiningOp();
2309  return producer && producer->hasOneUse();
2310  };
2311 
2312  // Add elementwise op fusion patterns.
 2313     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
 2314     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
 2315     tensor::populateBubbleUpExpandShapePatterns(patterns);
 2316 
2317  // General canonicalization patterns.
2318  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2319  GenericOp::getCanonicalizationPatterns(patterns, context);
2320  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2321  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2322  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2323  patterns);
2324 
2325  // Add constant folding patterns.
 2326     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
 2327 
2328  // Use TopDownTraversal for compile time reasons
2329  GreedyRewriteConfig grc;
2330  grc.useTopDownTraversal = true;
2331  (void)applyPatternsGreedily(op, std::move(patterns), grc);
2332  }
2333 };
2334 
2335 } // namespace
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ArrayRef< OpFoldResult > loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:654
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:29
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:571
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:850
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:412
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:656
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:642
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:632
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1208
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition: Utils.cpp:238
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
Definition: Transforms.h:1777
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:242
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
Definition: Transforms.h:1808
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:69
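Usage sketch, with an OpBuilder b, Location loc, and tensor-typed Value v assumed in scope; each returned entry is an Attribute for a static dimension and a Value for a dynamic one.

  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, v);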
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
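A sketch combining the two entries above, with val an assumed Value: test whether it is a constant integer and bind its value.

  IntegerAttr::ValueType intValue;
  if (matchPattern(val, m_ConstantInt(&intValue))) {
    // `val` is defined by a constant; `intValue` now holds it.
  }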
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:836
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
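A sketch of the map algebra these helpers enable, assuming an MLIRContext *ctx: build the permutation (d0, d1, d2) -> (d2, d0, d1), invert it, and apply it to a shape (applyPermutationMap from above).

  AffineExpr d0, d1, d2;
  bindDims(ctx, d0, d1, d2);
  AffineMap perm = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0,
                                  {d2, d0, d1}, ctx);
  // inv is (d0, d1, d2) -> (d1, d2, d0); composing the two gives the identity.
  AffineMap inv = inversePermutation(perm);
  SmallVector<int64_t> shape = {4, 8, 16};
  // Yields {16, 4, 8}.
  SmallVector<int64_t> permuted =
      applyPermutationMap(perm, ArrayRef<int64_t>(shape));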
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify the shapes of the reshaped types using the following rule: if a dimension in the collapsed type i...
ArrayRef< int64_t > ReassociationIndicesRef
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:362
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
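Sketch, reusing the sizes vector from the tensor::getMixedSizes sketch above: split it into its static/dynamic parts, and materialize one entry as a Value.

  auto [staticSizes, dynamicSizes] = decomposeMixedValues(sizes);
  Value firstSize = getValueOrCreateConstantIndexOp(b, loc, sizes.front());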
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
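This is the inverse of decomposeMixedValues above; a round-trip sketch:

  SmallVector<OpFoldResult> roundTripped =
      getMixedValues(staticSizes, dynamicSizes, b.getContext());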
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
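For example:

  // Inverting the permutation {2, 0, 1} yields {1, 2, 0}.
  SmallVector<int64_t> inverse = invertPermutationVector({2, 0, 1});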
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:379
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
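The fusion patterns in this file follow this shape; a skeletal sketch with a hypothetical pattern name:

  struct FuseWithProducer : public OpRewritePattern<linalg::GenericOp> {
    using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                  PatternRewriter &rewriter) const override {
      // Check preconditions, then mutate the IR only through `rewriter`.
      return failure();
    }
  };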
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
Definition: Transforms.h:502
llvm::DenseMap< Value, Value > replacements
Definition: Transforms.h:504
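Sketch of consuming the result after a successful fuseElementwiseOps call (fused and rewriter as in the earlier sketch); a simplified variant, since callers may need to restrict which uses are replaced:

  for (auto [origValue, replacement] : fused->replacements)
    rewriter.replaceAllUsesWith(origValue, replacement);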