//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop / fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
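
// Worked example (illustrative; not part of the upstream file): take a
// producer with result map (d0, d1) -> (d1, d0) and argument map
// (d0, d1) -> (d1), fused at a consumer operand whose map is
// (c0, c1) -> (c0, c1):
//   inverse of the result map : (t0, t1) -> (t1, t0)
//   argMap o inverse          : (t0, t1) -> (t0)
//   composed with consumer map: (c0, c1) -> (c0)
// i.e. the producer argument is indexed by the first fused (consumer) loop.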

// Checks if the given operand can be dropped, and the remaining operands
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
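
// Intuition (illustrative; not part of the upstream file): if the remaining
// maps are (d0, d1) -> (d0) and (d0, d1) -> (d1), their concatenation
// (d0, d1) -> (d0, d1) is invertible, so each loop bound is still defined by
// some remaining operand and the candidate operand can be dropped. If both
// remaining maps were (d0, d1) -> (d0), loop d1 would lose its bound and the
// drop is rejected.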

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}
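
// Illustrative example (not part of the upstream file): for a producer with
// two results where result #0 is used only by the fused consumer and result
// #1 also has other users, only index 1 ends up in the returned set; result
// #0 and its init operand can be dropped from the fused op.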

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  auto producerResult = cast<OpResult>(fusedOperand->get());
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
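
// Sketch of a fusable pair (illustrative; not part of the upstream file):
//
//   %0 = linalg.generic { all-parallel, permutation indexing maps }
//          ins(%a : tensor<?x?xf32>) outs(%init0 : tensor<?x?xf32>) ...
//   %1 = linalg.generic { ... } ins(%0, %b : ...) outs(%init1 : ...) ...
//
// The use of %0 as an input of %1 satisfies the checks above: both ops are
// generics on tensors, the producer is all-parallel, and the producer's
// result indexing map is a permutation.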

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return IndexOp::create(rewriter, producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = affine::AffineApplyOp::create(
          rewriter, producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside the producer block.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer,
  // after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = GenericOp::create(
      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Fuse with the producer of the operand.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the fusion.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace
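
// Typical driver (a sketch; the exact control function is up to the caller):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
// `populateElementwiseOpsFusionPatterns` adds FuseElementwiseOps with the
// given ControlFusionFn, which decides per-use whether to fuse.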

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
/// Consider
///
///   %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                           affine_map<(d0, d1, d2) -> (d1, d2)>,
///                           affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///   %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map helps compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
/// substituting this, the structured op can be rewritten as
///
///   %d = linalg.generic ins(%0, %1 : )
///          indexing_maps =
///           [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///            affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///            affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since operands to the linalg generic are now higher dimensional, reshapes
/// can be introduced to make it consistent
///
///   %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///   %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///        : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from the original dimensions of the op to the
  // dimensions of the expanded op given the `indexingMap` of the fused
  // operand/result of the generic op, the `reassociationMaps` of the reshape
  // op and the shape of the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of loops in the original operation to the extent
  /// of loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
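
// Worked example (illustrative; not part of the upstream file): for a fused
// operand with indexing map (d0, d1) -> (d0, d1), reassociation maps
// [[0, 1], [2]] and expanded shape [?, 4, ?]:
//   numExpandedDims  = [2, 1]
//   expandedShapeMap = [[?, 4], [?]]
//   reassociation    = [[0, 1], [2]], expandedOpNumDims = 3
// i.e. loop d0 is split into two expanded loops and d1 stays a single loop.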

/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
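
// For instance (illustrative; not part of the upstream file): with the
// expansion d0 -> {e0, e1} and d1 -> {e2}, the original indexing map
// (d0, d1) -> (d1, d0) becomes (e0, e1, e2) -> (e2, e0, e1).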

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
    OpFoldResult newIndex =
        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}
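
// Worked example (illustrative; not part of the upstream file): if the
// original dim d0 is expanded into (e0, e1) with static inner extent 4, the
// original index is recovered as
//   linalg.index(d0) == linalg.index(e1) + linalg.index(e0) * 4
// which is exactly the `idx + acc * shape` accumulation performed above.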

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we invert-permute it and
// then flatten it. Then we invert-permute it again to get the final expanded
// transpose permutation. For example,
//
//   permutation = [2, 0, 1]
//   reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
//   inverse permutation = [1, 2, 0]
//   applied to the reassociation_map and then flattened becomes
//   flattened permutation = [2, 3, 4, 5, 0, 1]
//   the final permutation is the inverse of the flattened permutation.
//
// Becomes
//
//   permutation = [4, 5, 0, 1, 2, 3]
static TransposeOp createExpandedTransposeOp(PatternRewriter &rewriter,
                                             TransposeOp transposeOp,
                                             Value expandedInput, Value output,
                                             ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
                             output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // Start with all-parallel iterator types, then copy the original iterator
  // type of each loop to all of its expanded positions.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
                                       expandedOpOperands, outputs,
                                       expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in the output shape before the
    // `linalgOp` to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(tensor::ExpandShapeOp::create(
          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(tensor::CollapseShapeOp::create(
          rewriter, linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Return the replacement values, one per original result.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
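
// For example (illustrative; not part of the upstream file): with
// indexingMap = (d0, d1, d2) -> (d2, d0, d1) and rangeReassociation = [1, 2]
// (positions in the map's results), the returned domain reassociation is
// [0, 1].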

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}
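
// For example (illustrative; not part of the upstream file): for
// indexingMap = (d0, d1, d2) -> (d1, d2, d0), the sequence [1, 2] is
// preserved (d1 and d2 appear contiguously and in order), while [0, 1] is
// not, since d1 appears in the results before d0.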

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}
1249 
1250 // Return the list of dimensions of the iteration domain that can be
1251 // collapsed to allow for fusion with the a producer that is an expand_shape
1252 // operation. If all dimensions created by expansion can be collapsed in the
1253 // iteration space then the reshape is defunct.
1254 //
1255 // Example:
1256 //
1257 // ```mlir
1258 // #map = affine_map<(d0, d1) -> (d0, d1)>
1259 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1260 // %2 = tensor.empty [..] : tensor<?x4xf32>
1261 // %3 = linalg.generic {
1262 // indexing_maps = [#map, #map],
1263 // iterator_types = ["parallel" ,"parallel"]}
1264 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1265 // ```
1266 //
1267 // can be fused by collapsing the dimensions of the iteration space.
1268 //
1269 // ```mlir
1270 // #map = affine_map<(d0) -> (d0)>
1271 // %2 = tensor.empty [..] : tensor<?xf32>
1272 // %3 = linalg.generic {
1273 // indexing_maps = [#map, #map],
1274 // iterator_types = ["parallel"]}
1275 // ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1276 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1277 // ```
1278 //
1279 // In the following example,
1280 //
1281 // ```mlir
1282 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
1283 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
1284 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1285 // %2 = tensor.empty [..] : tensor<4x?xf32>
1286 // %2 = linalg.generic {
1287 // indexing_maps = [#map0, #map1],
1288 // iterator_types = ["parallel" ,"parallel"]}
1289 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1290 // ```
1291 //
1292 // the reshape cannot be fused with the generic op by collapsing the op
1293 // dimensions since the indexing maps will have to contain mods and divs
1294 // to preserve the accesses pattern. When no dimensions of the iteration
1295 // space are collapsable and empty vector is returned.
1297 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1298  ArrayRef<ReassociationIndices> reassociation) {
1299  // Some basic checks for this fusion to be valid.
1300  if (!genericOp.hasPureTensorSemantics())
1301  return {};
1302 
1303  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1304  return map.isProjectedPermutation();
1305  })) {
1306  return {};
1307  }
1308 
1309  // Compute all the loops with the reduction iterator types.
1310  SmallVector<unsigned> reductionDims;
1311  genericOp.getReductionDims(reductionDims);
1312 
1313  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1314  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1315  auto iteratorTypes = genericOp.getIteratorTypesArray();
1316  SmallVector<ReassociationIndices> iterationSpaceReassociation;
1317  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1318  assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1319 
1320  // Ignore dims that are not folded.
1321  if (foldedRangeDims.size() == 1)
1322  continue;
1323 
1324  ReassociationIndices foldedIterationSpaceDims =
1325  getDomainReassociation(indexingMap, foldedRangeDims);
1326 
1327  // Check that the folded iteration dims do not contain already processed
1328  // dims.
1329  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1330  return processedIterationDims.count(dim);
1331  }))
1332  continue;
1333 
1334  // Check that all folded iterator types are all parallel or all reductions.
1335  utils::IteratorType startIteratorType =
1336  iteratorTypes[foldedIterationSpaceDims[0]];
1337  if (!isParallelIterator(startIteratorType) &&
1338  !isReductionIterator(startIteratorType))
1339  continue;
1340  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1341  return iteratorTypes[dim] != startIteratorType;
1342  }))
1343  continue;
1344 
1345  // If the folded dimensions correspond to a "reduction" iterator type,
1346  // the folded dimensions need to be "in-order". Strictly speaking this is
1347  // not necessary, for reductions that are associative and commutative, but
1348  // using a more strict definition of reduction for now.
1349  if (isReductionIterator(startIteratorType)) {
1350  bool isContiguous = false;
1351  for (const auto &startDim : llvm::enumerate(reductionDims)) {
1352  // Move window in `reductionDims` to start of the folded iteration dims.
1353  if (startDim.value() != foldedIterationSpaceDims[0])
1354  continue;
1355  // If sizes doesnt match, trivial not contiguous. This condition should
1356  // not be hit.
1357  if (startDim.index() + foldedIterationSpaceDims.size() >
1358  reductionDims.size())
1359  break;
1360  // Check that the contiguity is maintained.
1361  isContiguous = true;
1362  for (const auto &foldedDim :
1363  llvm::enumerate(foldedIterationSpaceDims)) {
1364  if (reductionDims[foldedDim.index() + startDim.index()] !=
1365  foldedDim.value()) {
1366  isContiguous = false;
1367  break;
1368  }
1369  }
1370  break;
1371  }
1372  if (!isContiguous)
1373  continue;
1374  }
1375 
1376  // Check that the sequence is preserved in all indexing maps.
1377  if (llvm::any_of(genericOp.getIndexingMapsArray(),
1378  [&](AffineMap indexingMap) {
1379  return !isDimSequencePreserved(indexingMap,
1380  foldedIterationSpaceDims);
1381  }))
1382  continue;
1383 
1384  processedIterationDims.insert_range(foldedIterationSpaceDims);
1385  iterationSpaceReassociation.emplace_back(
1386  std::move(foldedIterationSpaceDims));
1387  }
1388 
1389  return iterationSpaceReassociation;
1390 }
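
// For contrast with the negative example above, an illustrative case (not
// taken from the source) where collapsing succeeds: every indexing map
// accesses the expanded dims in-order, so the iteration dims (d0, d1) can be
// folded and this function returns [[0, 1]].
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel", "parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```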
1391 
1392 /// Helper class to carry state while collapsing the `linalg.generic` op.
1393 namespace {
1394 class CollapsingInfo {
1395 public:
1396  LogicalResult initialize(unsigned origNumLoops,
1397  ArrayRef<ReassociationIndices> foldedIterationDims) {
1398  llvm::SmallDenseSet<int64_t, 4> processedDims;
1399  // Find all the dims that are folded.
1400  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1401  if (foldedIterationDim.empty())
1402  continue;
1403  // If the folded dims contain dims that were already folded, the
1404  // specification is illegal. Repetition within a list is also illegal.
1405  for (auto dim : foldedIterationDim) {
1406  if (dim >= origNumLoops)
1407  return failure();
1408  if (processedDims.count(dim))
1409  return failure();
1410  processedDims.insert(dim);
1411  }
1412  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1413  foldedIterationDim.end());
1414  }
1415  if (processedDims.size() > origNumLoops)
1416  return failure();
1417 
1418  // Add all the preserved dims of the original op as single
1419  // elements to `collapsedOpToOrigOpIterationDim`.
1420  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1421  if (processedDims.count(dim))
1422  continue;
1423  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1424  }
1425 
1426  llvm::sort(collapsedOpToOrigOpIterationDim,
1427  [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1428  return lhs[0] < rhs[0];
1429  });
1430  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1431  for (const auto &foldedDims :
1432  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1433  for (const auto &dim : enumerate(foldedDims.value()))
1434  origOpToCollapsedOpIterationDim[dim.value()] =
1435  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1436  }
1437  return success();
1438  }
1439 
1440  /// Return mapping from collapsed loop domain to original loop domain.
1441  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1442  return collapsedOpToOrigOpIterationDim;
1443  }
1444 
1445  /// Return mapping from original loop domain to collapsed loop domain. The
1446  /// mapping is a pair. First value is the dimension in the collapsed loop that
1447  /// the original loop is mapped to. Second is the relative position in folded
1448  /// list of this domain. For example, if the original loop domain is 5D and
1449  /// the collapsed loop domain folds it into two dimensions, i.e.
1450  ///
1451  /// ```
1452  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1453  /// ```
1453  /// ```
1454  ///
1455  /// then
1456  ///
1457  /// ```
1458  /// origOpToCollapsedOpMapping[0] = {0, 0};
1459  /// origOpToCollapsedOpMapping[1] = {0, 1};
1460  /// origOpToCollapsedOpMapping[2] = {0, 2};
1461  /// origOpToCollapsedOpMapping[3] = {1, 0};
1462  /// origOpToCollapsedOpMapping[4] = {1, 1};
1463  /// ```
1464  ///
1465  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1466  return origOpToCollapsedOpIterationDim;
1467  }
1468 
1469  /// Return the collapsed op iteration domain rank.
1470  unsigned getCollapsedOpIterationRank() const {
1471  return collapsedOpToOrigOpIterationDim.size();
1472  }
1473 
1474 private:
1475  /// Map from the iteration domain index in collapsed op to the iteration
1476  /// domain indices in the original op.
1477  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1478 
1479  /// Map from iteration domain index in the original op to the iteration domain
1480  /// index in the collapsed op.
1481  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1482 };
1483 } // namespace
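
// Worked example (assumed inputs, for illustration): with origNumLoops = 4
// and foldedIterationDims = [[1, 2]], CollapsingInfo::initialize computes
//
// ```
// collapsedOpToOrigOpIterationDim = [[0], [1, 2], [3]]
// origOpToCollapsedOpIterationDim = [{0, 0}, {1, 0}, {1, 1}, {2, 0}]
// ```
//
// i.e. the 4-d iteration space is collapsed to 3 dims, with original dims 1
// and 2 folded into collapsed dim 1.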
1484 
1485 /// Get the iterator types for the collapsed operation given the original
1486 /// iterator types and collapsed dimensions.
1487 static SmallVector<utils::IteratorType>
1488 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1489  const CollapsingInfo &collapsingInfo) {
1490  SmallVector<utils::IteratorType> collapsedIteratorTypes;
1491  for (ReassociationIndicesRef foldedIterDims :
1492  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1493  assert(!foldedIterDims.empty() &&
1494  "reassociation indices expected to have non-empty sets");
1495  // Just pick the iterator type of the first folded dim. Pre-condition checks
1496  // are expected to have verified that the iterator types of all folded
1497  // dimensions are the same.
1498  collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1499  }
1500  return collapsedIteratorTypes;
1501 }
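
// Example (illustrative): for iteratorTypes = ["parallel", "parallel",
// "reduction"] and collapsing [[0, 1], [2]], the collapsed iterator types
// are ["parallel", "reduction"].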
1502 
1503 /// Compute the indexing map in the collapsed op that corresponds to the given
1504 /// `indexingMap` of the original operation.
1505 static AffineMap
1506 getCollapsedOpIndexingMap(AffineMap indexingMap,
1507  const CollapsingInfo &collapsingInfo) {
1508  MLIRContext *context = indexingMap.getContext();
1509  assert(indexingMap.isProjectedPermutation() &&
1510  "expected indexing map to be projected permutation");
1511  SmallVector<AffineExpr> resultExprs;
1512  auto origOpToCollapsedOpMapping =
1513  collapsingInfo.getOrigOpToCollapsedOpMapping();
1514  for (auto expr : indexingMap.getResults()) {
1515  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1516  // If the dim is not the first of its collapsed group, do nothing.
1517  if (origOpToCollapsedOpMapping[dim].second != 0)
1518  continue;
1519  // The next n-dims are guaranteed to be collapsed. So just use the
1520  // iteration dimension of the collapsed op.
1521  resultExprs.push_back(
1522  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1523  }
1524  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1525  resultExprs, context);
1526 }
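
// Example (illustrative): with collapsing [[0, 1], [2]], the original map
//
// ```mlir
// affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// ```
//
// becomes
//
// ```mlir
// affine_map<(d0, d1) -> (d0, d1)>
// ```
//
// where the new d0 spans the folded (d0, d1) pair and the new d1 is the
// old d2.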
1527 
1528 /// Return the `reassociation` indices to use to collapse the operand when the
1529 /// iteration space of a generic op is collapsed.
1530 static SmallVector<ReassociationIndices>
1531 getOperandReassociation(AffineMap indexingMap,
1532  const CollapsingInfo &collapsingInfo) {
1533  unsigned counter = 0;
1534  SmallVector<ReassociationIndices> operandReassociation;
1535  auto origOpToCollapsedOpMapping =
1536  collapsingInfo.getOrigOpToCollapsedOpMapping();
1537  auto collapsedOpToOrigOpMapping =
1538  collapsingInfo.getCollapsedOpToOrigOpMapping();
1539  while (counter < indexingMap.getNumResults()) {
1540  unsigned dim =
1541  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1542  // This is the start of a collapsed group of iteration dimensions that
1543  // is guaranteed to be preserved in the indexing map. The number of folded
1544  // dims is obtained from the collapsed op to original op mapping.
1545  unsigned numFoldedDims =
1546  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1547  .size();
1548  if (origOpToCollapsedOpMapping[dim].second == 0) {
1549  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1550  operandReassociation.emplace_back(range.begin(), range.end());
1551  }
1552  counter += numFoldedDims;
1553  }
1554  return operandReassociation;
1555 }
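
// Example (illustrative): for the indexing map
// affine_map<(d0, d1, d2) -> (d2, d0, d1)> and collapsing [[0, 1], [2]],
// the operand reassociation is [[0], [1, 2]]: result 0 (d2) is kept as-is
// and results 1 and 2 (the folded d0, d1 pair) are collapsed together.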
1556 
1557 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1558 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1559  OpOperand *opOperand,
1560  const CollapsingInfo &collapsingInfo,
1561  OpBuilder &builder) {
1562  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1563  SmallVector<ReassociationIndices> operandReassociation =
1564  getOperandReassociation(indexingMap, collapsingInfo);
1565 
1566  // If the number of entries in the reassociation for the operand is the same
1567  // as the number of results of the indexing map, then there is nothing to do
1568  // for this operand.
1569  Value operand = opOperand->get();
1570  if (operandReassociation.size() == indexingMap.getNumResults())
1571  return operand;
1572 
1573  // Insert a reshape to collapse the dimensions.
1574  if (isa<MemRefType>(operand.getType())) {
1575  return memref::CollapseShapeOp::create(builder, loc, operand,
1576  operandReassociation)
1577  .getResult();
1578  }
1579  return tensor::CollapseShapeOp::create(builder, loc, operand,
1580  operandReassociation)
1581  .getResult();
1582 }
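
// For instance (illustrative types), an operand of type tensor<?x4x8xf32>
// with operand reassociation [[0, 1], [2]] is rewritten as
//
// ```mlir
// %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
//     : tensor<?x4x8xf32> into tensor<?x8xf32>
// ```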
1583 
1584 /// Modify the `linalg.index` operations in the original generic op to their
1585 /// values in the collapsed operation.
1586 static void generateCollapsedIndexingRegion(
1587  Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1588  ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1589  OpBuilder::InsertionGuard g(rewriter);
1590  rewriter.setInsertionPointToStart(block);
1591 
1592  // Collect all the original index ops.
1593  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1594 
1595  // For each folded dimension list resolve the original induction variable
1596  // values in terms of the folded dimension induction variable.
1597  //   i_folded = (i0 * d1 + i1) * d2 + i2
1598  // can be inverted to
1599  //   i2 = i_folded % d2
1600  //   i1 = (i_folded / d2) % d1
1601  //   i0 = i_folded / (d1 * d2)
1602  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1603  for (auto foldedDims :
1604  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1605  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1606  Value newIndexVal =
1607  linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1608  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1609  Value loopDim =
1610  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1611  indexReplacementVals[dim] =
1612  rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1613  newIndexVal =
1614  rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1615  }
1616  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1617  }
1618 
1619  for (auto indexOp : indexOps) {
1620  auto dim = indexOp.getDim();
1621  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1622  }
1623 }
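
// Sketch of the rewrite (names illustrative): if dims 0 and 1, with inner
// loop bound %d1, are folded into a single dim, the original index values
// are recovered from the folded `linalg.index` as
//
// ```mlir
// %folded = linalg.index 0 : index
// %i1 = arith.remsi %folded, %d1 : index
// %i0 = arith.divsi %folded, %d1 : index
// ```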
1624 
1625 void collapseOperandsAndResults(LinalgOp op,
1626  const CollapsingInfo &collapsingInfo,
1627  RewriterBase &rewriter,
1628  SmallVectorImpl<Value> &inputOperands,
1629  SmallVectorImpl<Value> &outputOperands,
1630  SmallVectorImpl<Type> &resultTypes) {
1631  Location loc = op->getLoc();
1632  inputOperands =
1633  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1634  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1635  rewriter);
1636  });
1637 
1638  // Get the output operands and result types.
1639  resultTypes.reserve(op.getNumDpsInits());
1640  outputOperands.reserve(op.getNumDpsInits());
1641  for (OpOperand &output : op.getDpsInitsMutable()) {
1642  Value newOutput =
1643  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1644  outputOperands.push_back(newOutput);
1645  // If the op has "buffer semantics", then the init operands are ranked
1646  // memrefs and the op has no results.
1647  if (!op.hasPureBufferSemantics())
1648  resultTypes.push_back(newOutput.getType());
1649  }
1650 }
1651 
1652 /// Clone a `LinalgOp` to a collapsed version with the same name.
1653 template <typename OpTy>
1654 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1655  const CollapsingInfo &collapsingInfo) {
1656  return nullptr;
1657 }
1658 
1659 /// Collapse any `LinalgOp` that does not require any specialization such as
1660 /// indexing_maps, iterator_types, etc.
1661 template <>
1662 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1663  const CollapsingInfo &collapsingInfo) {
1664  SmallVector<Value> inputOperands, outputOperands;
1665  SmallVector<Type> resultTypes;
1666  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1667  outputOperands, resultTypes);
1668 
1669  return clone(
1670  rewriter, origOp, resultTypes,
1671  llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1672 }
1673 
1674 /// Collapse a `GenericOp`
1675 template <>
1676 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1677  GenericOp origOp,
1678  const CollapsingInfo &collapsingInfo) {
1679  SmallVector<Value> inputOperands, outputOperands;
1680  SmallVector<Type> resultTypes;
1681  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1682  outputOperands, resultTypes);
1683  SmallVector<AffineMap> indexingMaps(
1684  llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1685  return getCollapsedOpIndexingMap(map, collapsingInfo);
1686  }));
1687 
1688  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1689  origOp.getIteratorTypesArray(), collapsingInfo));
1690 
1691  GenericOp collapsedOp = linalg::GenericOp::create(
1692  rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1693  indexingMaps, iteratorTypes,
1694  [](OpBuilder &builder, Location loc, ValueRange args) {});
1695  Block *origOpBlock = &origOp->getRegion(0).front();
1696  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1697  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1698  collapsedOpBlock->getArguments());
1699  return collapsedOp;
1700 }
1701 
1702 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1703  RewriterBase &rewriter) {
1704  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1705  return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1706  } else {
1707  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1708  }
1709 }
1710 
1711 /// Implementation of fusion with reshape operation by collapsing dimensions.
1712 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1713  LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1714  RewriterBase &rewriter) {
1715  // Bail on trivial no-op cases.
1716  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1717  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1718  return foldedDims.size() <= 1;
1719  }))
1720  return failure();
1721 
1722  CollapsingInfo collapsingInfo;
1723  if (failed(
1724  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1725  return rewriter.notifyMatchFailure(
1726  op, "illegal to collapse specified dimensions");
1727  }
1728 
1729  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1730  if (hasPureBufferSemantics &&
1731  !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
1732  MemRefType memRefToCollapse =
1733  dyn_cast<MemRefType>(opOperand.get().getType());
1734  if (!memRefToCollapse)
1735  return true;
1736 
1737  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1738  SmallVector<ReassociationIndices> operandReassociation =
1739  getOperandReassociation(indexingMap, collapsingInfo);
1740  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1741  memRefToCollapse, operandReassociation);
1742  }))
1743  return rewriter.notifyMatchFailure(op,
1744  "memref is not guaranteed collapsible");
1745 
1746  // Bail on non-canonical ranges.
1747  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1748  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1749  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1750  return cast<IntegerAttr>(attr).getInt() == value;
1751  llvm::APInt actual;
1752  return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1753  actual.getSExtValue() == value;
1754  };
1755  if (!llvm::all_of(loopRanges, [&](Range range) {
1756  return opFoldIsConstantValue(range.offset, 0) &&
1757  opFoldIsConstantValue(range.stride, 1);
1758  })) {
1759  return rewriter.notifyMatchFailure(
1760  op, "expected all loop ranges to have zero start and unit stride");
1761  }
1762 
1763  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1764 
1765  Location loc = op->getLoc();
1766  SmallVector<OpFoldResult> loopBound =
1767  llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1768 
1769  if (collapsedOp.hasIndexSemantics()) {
1770  // Collect the loop range of the generic op.
1771  OpBuilder::InsertionGuard g(rewriter);
1772  rewriter.setInsertionPoint(collapsedOp);
1773  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1774  collapsingInfo, loopBound, rewriter);
1775  }
1776 
1777  // Insert expanding reshape for the result to get back the original result
1778  // type.
1779  SmallVector<Value> results;
1780  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1781  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1782  auto originalResultType =
1783  cast<ShapedType>(originalResult.value().getType());
1784  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1785  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1786  AffineMap indexingMap =
1787  op.getIndexingMapMatchingResult(originalResult.value());
1788  SmallVector<ReassociationIndices> reassociation =
1789  getOperandReassociation(indexingMap, collapsingInfo);
1790  assert(
1791  indexingMap.isProjectedPermutation() &&
1792  "Expected indexing map to be a projected permutation for collapsing");
1793  SmallVector<OpFoldResult> resultShape =
1794  applyPermutationMap(indexingMap, ArrayRef(loopBound));
1795  Value result;
1796  if (isa<MemRefType>(collapsedOpResult.getType())) {
1797  MemRefType expandShapeResultType = MemRefType::get(
1798  originalResultType.getShape(), originalResultType.getElementType());
1799  result = memref::ExpandShapeOp::create(
1800  rewriter, loc, expandShapeResultType, collapsedOpResult,
1801  reassociation, resultShape);
1802  } else {
1803  result = tensor::ExpandShapeOp::create(
1804  rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1805  resultShape);
1806  }
1807  results.push_back(result);
1808  } else {
1809  results.push_back(collapsedOpResult);
1810  }
1811  }
1812  return CollapseResult{results, collapsedOp};
1813 }
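
// End-to-end sketch (illustrative shapes, not from a test): collapsing dims
// [[0, 1]] of a 2-d elementwise op yields roughly
//
// ```mlir
// %in = tensor.collapse_shape %arg0 [[0, 1]]
//     : tensor<4x8xf32> into tensor<32xf32>
// %init = tensor.collapse_shape %arg1 [[0, 1]]
//     : tensor<4x8xf32> into tensor<32xf32>
// %res = linalg.generic {
//     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
//     iterator_types = ["parallel"]}
//     ins(%in : tensor<32xf32>) outs(%init : tensor<32xf32>) {.. }
// %expanded = tensor.expand_shape %res [[0, 1]]
//     : tensor<32xf32> into tensor<4x8xf32>
// ```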
1814 
1815 namespace {
1816 
1817 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1818 /// contracting dimensions of the loop.
1819 class FoldWithProducerReshapeOpByCollapsing
1820  : public OpRewritePattern<GenericOp> {
1821 public:
1822  // TODO: support fusion with all linalg ops, not just generic.
1823  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1824  ControlFusionFn foldReshapes,
1825  PatternBenefit benefit = 1)
1826  : OpRewritePattern<GenericOp>(context, benefit),
1827  controlFoldingReshapes(std::move(foldReshapes)) {}
1828 
1829  LogicalResult matchAndRewrite(GenericOp genericOp,
1830  PatternRewriter &rewriter) const override {
1831  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1832  tensor::ExpandShapeOp reshapeOp =
1833  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1834  if (!reshapeOp)
1835  continue;
1836 
1837  SmallVector<ReassociationIndices> collapsableIterationDims =
1838  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1839  reshapeOp.getReassociationIndices());
1840  if (collapsableIterationDims.empty() ||
1841  !controlFoldingReshapes(&opOperand)) {
1842  continue;
1843  }
1844 
1845  std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1846  genericOp, collapsableIterationDims, rewriter);
1847  if (!collapseResult) {
1848  return rewriter.notifyMatchFailure(
1849  genericOp, "failed to do the fusion by collapsing transformation");
1850  }
1851 
1852  rewriter.replaceOp(genericOp, collapseResult->results);
1853  return success();
1854  }
1855  return failure();
1856  }
1857 
1858 private:
1859  ControlFusionFn controlFoldingReshapes;
1860 };
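
// Sketch of the intent (illustrative): the pattern rewrites
//
// ```mlir
// %e = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %r = linalg.generic ... ins(%e : tensor<?x4xf32>) ... {.. }
// ```
//
// into a generic op on the collapsed iteration space that consumes %0
// directly, with the remaining operands and results reshaped as needed.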
1861 
1862 /// Pattern to fold a tensor.collapse_shape op with its producer generic op
1863 /// by expanding the dimensionality of the loop in the producer op.
1864 struct FoldReshapeWithGenericOpByCollapsing
1865  : public OpRewritePattern<tensor::CollapseShapeOp> {
1866 
1867  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1868  ControlFusionFn foldReshapes,
1869  PatternBenefit benefit = 1)
1870  : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1871  controlFoldingReshapes(std::move(foldReshapes)) {}
1872 
1873  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1874  PatternRewriter &rewriter) const override {
1875  // Fold only if all constraints of fusing with reshape by collapsing are
1876  // met.
1877  auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1878  if (!producerResult) {
1879  return rewriter.notifyMatchFailure(reshapeOp,
1880  "source not produced by an operation");
1881  }
1882 
1883  // TODO: support fusion with all linalg producers, not just generic.
1884  auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1885  if (!producer) {
1886  return rewriter.notifyMatchFailure(reshapeOp,
1887  "producer not a generic op");
1888  }
1889 
1890  SmallVector<ReassociationIndices> collapsableIterationDims =
1891  getCollapsableIterationSpaceDims(
1892  producer,
1893  producer.getDpsInitOperand(producerResult.getResultNumber()),
1894  reshapeOp.getReassociationIndices());
1895  if (collapsableIterationDims.empty()) {
1896  return rewriter.notifyMatchFailure(
1897  reshapeOp, "failed preconditions of fusion with producer generic op");
1898  }
1899 
1900  if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1901  return rewriter.notifyMatchFailure(reshapeOp,
1902  "fusion blocked by control function");
1903  }
1904 
1905  // Set the insertion point after `producer` because there could be uses
1906  // of `producer` between it and the `tensor.collapse_shape` op.
1907  rewriter.setInsertionPointAfter(producer);
1908  std::optional<CollapseResult> collapseResult =
1909  collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
1910  if (!collapseResult) {
1911  return rewriter.notifyMatchFailure(
1912  producer, "failed to do the fusion by collapsing transformation");
1913  }
1914 
1915  rewriter.replaceOp(producer, collapseResult->results);
1916  return success();
1917  }
1918 
1919 private:
1920  ControlFusionFn controlFoldingReshapes;
1921 };
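
// Sketch (illustrative): here the reshape is the consumer,
//
// ```mlir
// %r = linalg.generic ... {.. } -> tensor<?x4xf32>
// %c = tensor.collapse_shape %r [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
// ```
//
// and the producer generic is collapsed so that the tensor.collapse_shape
// folds away.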
1922 
1923 class FoldPadWithProducerReshapeOpByCollapsing
1924  : public OpRewritePattern<tensor::PadOp> {
1925 public:
1926  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1927  ControlFusionFn foldReshapes,
1928  PatternBenefit benefit = 1)
1929  : OpRewritePattern<tensor::PadOp>(context, benefit),
1930  controlFoldingReshapes(std::move(foldReshapes)) {}
1931 
1932  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1933  PatternRewriter &rewriter) const override {
1934  tensor::ExpandShapeOp reshapeOp =
1935  padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1936  if (!reshapeOp)
1937  return failure();
1938  if (!reshapeOp->hasOneUse())
1939  return failure();
1940 
1941  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1942  return rewriter.notifyMatchFailure(padOp,
1943  "fusion blocked by control function");
1944  }
1945 
1946  ArrayRef<int64_t> low = padOp.getStaticLow();
1947  ArrayRef<int64_t> high = padOp.getStaticHigh();
1948  SmallVector<ReassociationIndices> reassociations =
1949  reshapeOp.getReassociationIndices();
1950 
1951  for (auto reInd : reassociations) {
1952  if (reInd.size() == 1)
1953  continue;
1954  if (llvm::any_of(reInd, [&](int64_t ind) {
1955  return low[ind] != 0 || high[ind] != 0;
1956  })) {
1957  return failure();
1958  }
1959  }
1960 
1961  SmallVector<OpFoldResult> newLow, newHigh;
1962  RankedTensorType collapsedType = reshapeOp.getSrcType();
1963  RankedTensorType paddedType = padOp.getResultType();
1964  SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1965  SmallVector<OpFoldResult> expandedPaddedSizes(
1966  getMixedValues(reshapeOp.getStaticOutputShape(),
1967  reshapeOp.getOutputShape(), rewriter));
1968  AffineExpr d0, d1, d2;
1969  bindDims(rewriter.getContext(), d0, d1, d2);
1970  auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1971  Location loc = reshapeOp->getLoc();
1972  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1973  OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1974  OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1975  if (reInd.size() == 1) {
1976  collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1977  OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
1978  rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1979  expandedPaddedSizes[reInd[0]] = paddedSize;
1980  }
1981  newLow.push_back(l);
1982  newHigh.push_back(h);
1983  }
1984 
1985  RankedTensorType collapsedPaddedType =
1986  paddedType.clone(collapsedPaddedShape);
1987  auto newPadOp = tensor::PadOp::create(
1988  rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1989  padOp.getConstantPaddingValue(), padOp.getNofold());
1990 
1991  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1992  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1993  expandedPaddedSizes);
1994 
1995  return success();
1996  }
1997 
1998 private:
1999  ControlFusionFn controlFoldingReshapes;
2000 };
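
// Sketch (illustrative shapes): when the pad only touches dims that are not
// expanded, e.g.
//
// ```mlir
// %e = tensor.expand_shape %0 [[0, 1], [2]]
//     : tensor<32x?xf32> into tensor<4x8x?xf32>
// %p = tensor.pad %e low[0, 0, 1] high[0, 0, 2] {.. }
//     : tensor<4x8x?xf32> to tensor<4x8x?xf32>
// ```
//
// the pattern instead pads the collapsed source and expands the result:
//
// ```mlir
// %p = tensor.pad %0 low[0, 1] high[0, 2] {.. }
//     : tensor<32x?xf32> to tensor<32x?xf32>
// %e = tensor.expand_shape %p [[0, 1], [2]]
//     : tensor<32x?xf32> into tensor<4x8x?xf32>
// ```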
2001 
2002 /// Pattern to collapse dimensions.
2003 template <typename LinalgType>
2004 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2005 public:
2006  CollapseLinalgDimensions(MLIRContext *context,
2007  GetCollapsableDimensionsFn collapseDimensions,
2008  PatternBenefit benefit = 1)
2009  : OpRewritePattern<LinalgType>(context, benefit),
2010  controlCollapseDimension(std::move(collapseDimensions)) {}
2011 
2012  LogicalResult matchAndRewrite(LinalgType op,
2013  PatternRewriter &rewriter) const override {
2014  SmallVector<ReassociationIndices> collapsableIterationDims =
2015  controlCollapseDimension(op);
2016  if (collapsableIterationDims.empty())
2017  return failure();
2018 
2019  // Check if the specified list of dimensions to collapse is a valid list.
2020  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2021  collapsableIterationDims)) {
2022  return rewriter.notifyMatchFailure(
2023  op, "specified dimensions cannot be collapsed");
2024  }
2025 
2026  std::optional<CollapseResult> collapseResult =
2027  collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2028  if (!collapseResult) {
2029  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2030  }
2031  rewriter.replaceOp(op, collapseResult->results);
2032  return success();
2033  }
2034 
2035 private:
2036  GetCollapsableDimensionsFn controlCollapseDimension;
2037 };
2038 
2039 } // namespace
2040 
2041 //===---------------------------------------------------------------------===//
2042 // Methods and patterns that fuse constants with linalg.generic operations.
2043 //===---------------------------------------------------------------------===//
2044 
2045 namespace {
2046 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2047 /// handle cases where the constant is not single-valued.
2048 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2049 public:
2050  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2051  : OpRewritePattern<GenericOp>(context, benefit) {}
2052 
2053  LogicalResult matchAndRewrite(GenericOp genericOp,
2054  PatternRewriter &rewriter) const override {
2055  if (!genericOp.hasPureTensorSemantics())
2056  return failure();
2057  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2058  Operation *def = opOperand->get().getDefiningOp();
2059  TypedAttr constantAttr;
2060  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2061  {
2062  DenseElementsAttr splatAttr;
2063  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2064  splatAttr.isSplat() &&
2065  splatAttr.getType().getElementType().isIntOrFloat()) {
2066  constantAttr = splatAttr.getSplatValue<TypedAttr>();
2067  return true;
2068  }
2069  }
2070  {
2071  IntegerAttr intAttr;
2072  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2073  constantAttr = intAttr;
2074  return true;
2075  }
2076  }
2077  {
2078  FloatAttr floatAttr;
2079  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2080  constantAttr = floatAttr;
2081  return true;
2082  }
2083  }
2084  return false;
2085  };
2086 
2087  auto resultValue = dyn_cast<OpResult>(opOperand->get());
2088  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2089  continue;
2090 
2091  // The operands and the indexing_maps of the fused operation are the same
2092  // as the operands and indexing_maps of the generic operation, with the
2093  // values at the constant index dropped.
2094  SmallVector<AffineMap> fusedIndexMaps;
2095  SmallVector<Value> fusedOperands;
2096  SmallVector<Location> fusedLocs{genericOp.getLoc()};
2097  fusedIndexMaps.reserve(genericOp->getNumOperands());
2098  fusedOperands.reserve(genericOp.getNumDpsInputs());
2099  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2100  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2101  if (inputOperand == opOperand)
2102  continue;
2103  Value inputValue = inputOperand->get();
2104  fusedIndexMaps.push_back(
2105  genericOp.getMatchingIndexingMap(inputOperand));
2106  fusedOperands.push_back(inputValue);
2107  fusedLocs.push_back(inputValue.getLoc());
2108  }
2109  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2110  fusedIndexMaps.push_back(
2111  genericOp.getMatchingIndexingMap(&outputOperand));
2112 
2113  // Check that the operation's shapes-to-loops map is computable.
2114  if (!inversePermutation(
2115  concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2116  return rewriter.notifyMatchFailure(
2117  genericOp, "fused op loop bound computation failed");
2118  }
2119 
2120  // Create a constant scalar value from the splat constant.
2121  Value scalarConstant =
2122  arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
2123 
2124  SmallVector<Value> outputOperands = genericOp.getOutputs();
2125  auto fusedOp =
2126  GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
2127  genericOp->getResultTypes(),
2128  /*inputs=*/fusedOperands,
2129  /*outputs=*/outputOperands,
2130  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2131  genericOp.getIteratorTypes(),
2132  /*doc=*/nullptr,
2133  /*library_call=*/nullptr);
2134 
2135  // Map the block argument corresponding to the replaced operand to the
2136  // scalar constant.
2137  Region &region = genericOp->getRegion(0);
2138  Block &entryBlock = *region.begin();
2139  IRMapping mapping;
2140  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2141  scalarConstant);
2142  Region &fusedRegion = fusedOp->getRegion(0);
2143  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2144  mapping);
2145  rewriter.replaceOp(genericOp, fusedOp->getResults());
2146  return success();
2147  }
2148  return failure();
2149  }
2150 };
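
// Sketch (illustrative): a splat input such as
//
// ```mlir
// %cst = arith.constant dense<4.2e+01> : tensor<4x8xf32>
// %r = linalg.generic ...
//     ins(%arg0, %cst : tensor<4x8xf32>, tensor<4x8xf32>) ... {.. }
// ```
//
// is rewritten to a generic op with one fewer input whose body uses the
// scalar `arith.constant 4.2e+01 : f32` directly.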
2151 
2152 } // namespace
2153 
2154 //===---------------------------------------------------------------------===//
2155 // Miscellaneous patterns that help fusion.
2156 //===---------------------------------------------------------------------===//
2157 
2158 namespace {
2159 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2160 /// value of the `outs` operand is not used within the op. This is only
2161 /// implemented for `linalg.generic` operations for now, but should hold for all
2162 /// linalg structured ops.
2163 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2164  using OpRewritePattern<GenericOp>::OpRewritePattern;
2165 
2166  LogicalResult matchAndRewrite(GenericOp op,
2167  PatternRewriter &rewriter) const override {
2168  rewriter.startOpModification(op);
2169  bool modifiedOutput = false;
2170  Location loc = op.getLoc();
2171  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2172  if (!op.payloadUsesValueFromOperand(&opOperand)) {
2173  Value operandVal = opOperand.get();
2174  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2175  if (!operandType)
2176  continue;
2177 
2178  // If outs is sparse, leave it to the sparsifier.
2179  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2180  continue;
2181 
2182  // If outs is already an `empty` operation, nothing to do.
2183  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2184  if (definingOp)
2185  continue;
2186  modifiedOutput = true;
2187  SmallVector<OpFoldResult> mixedSizes =
2188  tensor::getMixedSizes(rewriter, loc, operandVal);
2189  Value emptyTensor = tensor::EmptyOp::create(
2190  rewriter, loc, mixedSizes, operandType.getElementType());
2191  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2192  }
2193  }
2194  if (!modifiedOutput) {
2195  rewriter.cancelOpModification(op);
2196  return failure();
2197  }
2198  rewriter.finalizeOpModification(op);
2199  return success();
2200  }
2201 };
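
// Sketch (illustrative): if the payload never reads the `outs` block
// argument, an init such as
//
// ```mlir
// %fill = linalg.fill ins(%cst : f32) outs(%t : tensor<?x?xf32>)
//     -> tensor<?x?xf32>
// %r = linalg.generic ... outs(%fill : tensor<?x?xf32>) {.. }
// ```
//
// is replaced by a fresh `tensor.empty` of the same shape, removing the
// spurious dependency on %fill.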
2202 
2203 /// Fold linalg.fill into linalg.generic
2204 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2205  using OpRewritePattern<GenericOp>::OpRewritePattern;
2206 
2207  LogicalResult matchAndRewrite(GenericOp genericOp,
2208  PatternRewriter &rewriter) const override {
2209  if (!genericOp.hasPureTensorSemantics())
2210  return failure();
2211  bool fillFound = false;
2212  Block &payload = genericOp.getRegion().front();
2213  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2214  if (!genericOp.payloadUsesValueFromOperand(opOperand))
2215  continue;
2216  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2217  if (!fillOp)
2218  continue;
2219  fillFound = true;
2220  Value fillVal = fillOp.value();
2221  auto resultType =
2222  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2223  Value convertedVal =
2224  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2225  /*isUnsignedCast =*/false);
2226  rewriter.replaceAllUsesWith(
2227  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2228  }
2229  return success(fillFound);
2230  }
2231 };
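
// Sketch (illustrative): for an input produced by
//
// ```mlir
// %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>)
//     -> tensor<4xf32>
// %r = linalg.generic ... ins(%fill : tensor<4xf32>) ... {.. }
// ```
//
// the corresponding block argument is replaced by %cst (converted to the
// result element type), after which the fill is dead if otherwise unused.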
2232 } // namespace
2233 
2234 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2235  RewritePatternSet &patterns,
2236  const ControlFusionFn &controlFoldingReshapes) {
2237  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2238  controlFoldingReshapes);
2239  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2240  controlFoldingReshapes);
2241  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2242  controlFoldingReshapes);
2243 }
2244 
2245 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2246  RewritePatternSet &patterns,
2247  const ControlFusionFn &controlFoldingReshapes) {
2248  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2249  controlFoldingReshapes);
2250  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2251  patterns.getContext(), controlFoldingReshapes);
2252  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2253  controlFoldingReshapes);
2254 }
2255 
2256 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2257  RewritePatternSet &patterns,
2258  const ControlFusionFn &controlElementwiseOpsFusion) {
2259  auto *context = patterns.getContext();
2260  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2261  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2262  RemoveOutsDependency>(context);
2263  // Add the patterns that clean up dead operands and results.
2264  populateEraseUnusedOperandsAndResultsPatterns(patterns);
2265 }
2266 
2267 void mlir::linalg::populateCollapseDimensions(
2268  RewritePatternSet &patterns,
2269  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2270  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2271  CollapseLinalgDimensions<linalg::CopyOp>>(
2272  patterns.getContext(), controlCollapseDimensions);
2273 }
2274 
2275 //===---------------------------------------------------------------------===//
2276 // Passes
2277 //===---------------------------------------------------------------------===//
2278 
2279 namespace {
2280 
2281 /// Pass that fuses generic ops on tensors. Used only for testing.
2282 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2283 // patterns added here heavily depends on the cost function used. Having an
2284 // opinionated pass of this form is not recommended. Deprecate this pass in
2285 // favor of test passes that check the functionality of each of the patterns
2286 // added here individually.
2287 struct LinalgElementwiseOpFusionPass
2288  : public impl::LinalgElementwiseOpFusionPassBase<
2289  LinalgElementwiseOpFusionPass> {
2290  using impl::LinalgElementwiseOpFusionPassBase<
2291  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2292  void runOnOperation() override {
2293  Operation *op = getOperation();
2294  MLIRContext *context = op->getContext();
2295  RewritePatternSet patterns(context);
2296 
2297  // Add folding with reshape by expansion patterns.
2298  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2299  Operation *producer = fusedOperand->get().getDefiningOp();
2300  return producer && producer->hasOneUse();
2301  };
2302 
2303  // Add elementwise op fusion patterns.
2304  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2305  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2306  tensor::populateBubbleUpExpandShapePatterns(patterns);
2307 
2308  // General canonicalization patterns.
2309  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2310  GenericOp::getCanonicalizationPatterns(patterns, context);
2311  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2312  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2313  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2314  patterns);
2315 
2316  // Add constant folding patterns.
2317  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2318 
2319  // Use TopDownTraversal for compile time reasons.
2320  (void)applyPatternsGreedily(op, std::move(patterns),
2321  GreedyRewriteConfig().setUseTopDownTraversal());
2322  }
2323 };
2324 
2325 } // namespace
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ArrayRef< OpFoldResult > loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:411
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:647
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:641
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:26
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:317
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:579
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:519
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:632
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:622
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition: Utils.cpp:235
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
Definition: Transforms.h:1902
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:239
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
Definition: Transforms.h:1938
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:829
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:238
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
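A sketch of the usual idiom: bind dimension expressions, then assemble a map such as (d0, d1) -> (d1, d0) (ctx is an assumed MLIRContext*):

AffineExpr d0, d1;
bindDims(ctx, d0, d1);
// (d0, d1) -> (d1, d0): a transpose-like permutation map.
AffineMap swapMap =
    AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, d0}, ctx);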
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:784
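For instance (a sketch; permMap assumed in scope): inverting the permutation map (d0, d1, d2) -> (d1, d2, d0) yields (d0, d1, d2) -> (d2, d0, d1); a non-invertible map produces a null AffineMap, so callers typically assert on the result.

AffineMap inv = inversePermutation(permMap);
assert(inv && "expected map to be invertible");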
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify the shapes of the reshaped types using the following rule: if a dimension in the collapsed type is static, the corresponding dimensions in the expanded shape must be static and their product must equal the collapsed dimension.
ArrayRef< int64_t > ReassociationIndicesRef
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with all elements for which ShapedType::isDynamic is true replaced by the corresponding value from dynamicValues.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace directly.
Definition: AffineExpr.cpp:619
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
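These two helpers are inverses of one another; a sketch (mixedSizes and ctx assumed in scope):

auto [staticSizes, dynamicSizes] = decomposeMixedValues(mixedSizes);
// Rebuild the mixed form: dynamic entries are re-injected in order.
SmallVector<OpFoldResult> rebuilt =
    getMixedValues(staticSizes, dynamicSizes, ctx);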
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
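For example, inverting [2, 0, 1] gives [1, 2, 0], since element i of the inverse records where i appears in the input:

SmallVector<int64_t> inverse = invertPermutationVector({2, 0, 1});
// inverse == [1, 2, 0]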
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:333
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
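A sketch of the usual shape of such a pattern (MyOp and the struct name are placeholders, not ops from this file):

struct SimplifyMyOp : public OpRewritePattern<MyOp> {
  using OpRewritePattern<MyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    // Match here; perform all IR mutation through `rewriter`.
    return failure();
  }
};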
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedOperand.
Definition: Transforms.h:561
llvm::DenseMap< Value, Value > replacements
Definition: Transforms.h:563