MLIR  19.0.0git
ElementwiseOpFusion.cpp
Go to the documentation of this file.
1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Support/LLVM.h"
26 #include <optional>
27 #include <utility>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
31 #include "mlir/Dialect/Linalg/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 
37 //===---------------------------------------------------------------------===//
38 // Methods and patterns that fuse elementwise `linalg.generic` operations.
39 //===---------------------------------------------------------------------===//
40 
41 /// Returns the indexing map, in the coordinates of the fused op's loops, to use
42 /// for `producerOpOperand` in the fused operation, given the producer result's
43 /// indexing map and the consumer's indexing map for the fused operand.
45  OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
46  AffineMap fusedConsumerArgIndexMap) {
47  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
48  // from consumer loop -> consumer arg tensor index/producer result tensor
49  // index. The fused loop is same as the consumer loop. For each producer arg
50  // the indexing map to be computed is a map from consumer loop -> producer
51  // arg tensor index.
52  // producerResultIndexMap is a map from producer loop -> tensor index.
53  // Compute the inverse to get map from tensor index -> producer loop.
54  // The inverse is a map from producer result tensor index -> producer loop.
55  AffineMap invProducerResultIndexMap =
56  inversePermutation(producerResultIndexMap);
57  assert(invProducerResultIndexMap &&
58  "expected producer result indexing map to be invertible");
59 
60  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
61  // argMap is a map from producer loop -> producer arg tensor index.
62  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
63 
64  // Compose argMap with invProducerResultIndexMap to get a map from
65  // producer result tensor index -> producer arg tensor index.
66  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
67 
68  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
69  // consumer loop/ fused loop -> producer arg tensor index.
70  return t1.compose(fusedConsumerArgIndexMap);
71 }
72 
73 /// Returns a set of indices of the producer's results which would
74 /// be preserved after the fusion.
75 llvm::SmallDenseSet<int>
77  GenericOp consumer) {
78  llvm::SmallDenseSet<int> preservedProducerResults;
79  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
80  auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
81  if (producer.payloadUsesValueFromOperand(outputOperand) ||
82  !producer.canOpOperandsBeDropped(outputOperand) ||
83  llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
84  return user != consumer.getOperation();
85  })) {
86  preservedProducerResults.insert(producerResult.index());
87  }
88  }
89  return preservedProducerResults;
90 }
91 
92 /// Conditions for elementwise fusion of generic operations.
94  if (!fusedOperand)
95  return false;
96 
97  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
98  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
99 
100  // Check producer and consumer are generic ops.
101  if (!producer || !consumer)
102  return false;
103 
104  // Consumer can have mixed semantics, just check operand itself has tensor
105  // type. Producer must have full tensor semantics to avoid potential
106  // aliasing between producer and consumer memrefs.
107  if (!producer.hasPureTensorSemantics() ||
108  !isa<RankedTensorType>(fusedOperand->get().getType()))
109  return false;
110 
111  // Verify that
112  // - the producer has all "parallel" iterator type.
113  if (producer.getNumParallelLoops() != producer.getNumLoops())
114  return false;
115 
116  // Only allow fusing the producer of an input operand for now.
117  // TODO: allow fusing the producer of an output operand.
118  if (!consumer.isDpsInput(fusedOperand))
119  return false;
120 
121  // Get the consumer index map. The number of results of the consumer index
122  // map must match the number of loops of the producer.
123  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
124  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
125  return false;
126 
127  // Finally the index_map for the result must be invertible. For now just
128  // verify it is a permutation.
129  AffineMap producerResultIndexMap =
130  producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
131  if (!producerResultIndexMap.isPermutation())
132  return false;
133 
134  // Ensure that the fusion does not remove size information required to
135  // get the loop bounds. For non-reduction generics, this is trivially the
136  // case due to the output operand. For reductions, we need to check that after
137  // the fusion, each loop dimension has at least one input that defines it.
138  if ((consumer.getNumReductionLoops())) {
139  BitVector coveredDims(consumer.getNumLoops(), false);
140 
141  auto addToCoveredDims = [&](AffineMap map) {
142  for (auto result : map.getResults())
143  if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
144  coveredDims[dimExpr.getPosition()] = true;
145  };
146 
147  for (auto pair :
148  llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
149  Value operand = std::get<0>(pair);
150  if (operand == fusedOperand->get())
151  continue;
152  AffineMap operandMap = std::get<1>(pair);
153  addToCoveredDims(operandMap);
154  }
155 
156  for (OpOperand *operand : producer.getDpsInputOperands()) {
157  AffineMap newIndexingMap =
159  operand, producerResultIndexMap, consumerIndexMap);
160  addToCoveredDims(newIndexingMap);
161  }
162  if (!coveredDims.all())
163  return false;
164  }
165 
166  return true;
167 }
168 
169 /// Generate the region of the fused tensor operation. The region of the fused
170 /// op must be empty.
172  RewriterBase &rewriter, GenericOp fusedOp,
173  AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
174  unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
175  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
176  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
177  // Build the region of the fused op.
178  Block &producerBlock = producer->getRegion(0).front();
179  Block &consumerBlock = consumer->getRegion(0).front();
180  OpBuilder::InsertionGuard guard(rewriter);
181  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
182  IRMapping mapper;
183 
184  // 2. Add an index operation for every fused loop dimension and use the
185  // `consumerToProducerLoopsMap` to map the producer indices.
186  if (producer.hasIndexSemantics()) {
187  // Add an index operation for every fused loop dimension.
188  unsigned numFusedOpLoops =
189  std::max(producer.getNumLoops(), consumer.getNumLoops());
190  SmallVector<Value> fusedIndices;
191  fusedIndices.reserve(numFusedOpLoops);
192  llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
193  std::back_inserter(fusedIndices), [&](uint64_t dim) {
194  return rewriter.create<IndexOp>(producer.getLoc(), dim);
195  });
196  for (IndexOp indexOp :
197  llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
198  Value newIndex = rewriter.create<affine::AffineApplyOp>(
199  producer.getLoc(),
200  consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
201  mapper.map(indexOp.getResult(), newIndex);
202  }
203  }
204  // TODO: allow fusing the producer of an output operand.
205  assert(consumer.isDpsInput(fusedOperand) &&
206  "expected producer of input operand");
207  // 3. Consumer input operands up to consumerIdx (exclusive).
208  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
209  fusedOperand->getOperandNumber())) // input assumption.
210  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
211 
212  // Replacing consumerIdx requires getting the cloned, yielded, value from
213  // the (cloned) producer block. This happens in step 9.
214 
215  // 4. Splice in producer's input operands.
216  for (BlockArgument bbArg :
217  producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
218  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
219 
220  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
221  for (BlockArgument bbArg :
222  consumerBlock.getArguments()
223  .take_front(consumer.getNumDpsInputs())
224  .drop_front(fusedOperand->getOperandNumber() + 1))
225  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
226 
227  // 6. All of the producer's output operands
228  for (const auto &bbArg : llvm::enumerate(
229  producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
230  if (!preservedProducerResults.count(bbArg.index()))
231  continue;
232  mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
233  bbArg.value().getLoc()));
234  }
235 
236  // 7. All of consumer's output operands.
237  for (BlockArgument bbArg :
238  consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
239  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
240 
241  // 8. Clone all producer operations except for the yield and index operations
242  // to the fused operation.
243  for (auto &op : producerBlock.without_terminator()) {
244  if (!isa<IndexOp>(op))
245  rewriter.clone(op, mapper);
246  }
247  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
248  // forward the yield operand.
249  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
250  unsigned producerResultNumber =
251  cast<OpResult>(fusedOperand->get()).getResultNumber();
252  Value replacement =
253  mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
254 
255  // Sanity checks, if replacement is not already in the mapper then it must be
256  // produced outside.
257  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
258  if (auto bb = dyn_cast<BlockArgument>(replacement))
259  assert(bb.getOwner() != &producerBlock &&
260  "yielded block argument must have been mapped");
261  else
262  assert(!producer->isAncestor(replacement.getDefiningOp()) &&
263  "yielded value must have been mapped");
264  }
265  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
266  replacement);
267  // 10. Clone operations from the consumer to the fused op.
268  for (auto &op : consumerBlock.without_terminator())
269  rewriter.clone(op, mapper);
270 
271  // 11. Include the final yield (which is the remapped values for all the
272  // yield)
273  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
274  SmallVector<Value> fusedYieldValues;
275  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
276  consumerYieldOp.getNumOperands());
277  for (const auto &producerYieldVal :
278  llvm::enumerate(producerYieldOp.getOperands())) {
279  if (preservedProducerResults.count(producerYieldVal.index()))
280  fusedYieldValues.push_back(
281  mapper.lookupOrDefault(producerYieldVal.value()));
282  }
283  for (auto consumerYieldVal : consumerYieldOp.getOperands())
284  fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
285  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
286 
287  // Sanity checks.
288  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
289  "Ill-formed GenericOp region");
290 }
291 
292 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
294  OpOperand *fusedOperand) {
295  assert(areElementwiseOpsFusable(fusedOperand) &&
296  "expected elementwise operation pre-conditions to pass");
297  auto producerResult = cast<OpResult>(fusedOperand->get());
298  auto producer = cast<GenericOp>(producerResult.getOwner());
299  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
300  // TODO: allow fusing the producer of an output operand.
301  assert(consumer.isDpsInput(fusedOperand) &&
302  "expected producer of input operand");
303  /// Find the results of the producer that have uses outside of the consumer.
304  llvm::SmallDenseSet<int> preservedProducerResults =
306  consumer);
307 
308  // Compute the fused operands list and indexing maps.
309  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
310  SmallVector<Type> fusedResultTypes;
311  SmallVector<AffineMap> fusedIndexMaps;
312  fusedInputOperands.reserve(producer.getNumDpsInputs() +
313  consumer.getNumDpsInputs());
314  fusedOutputOperands.reserve(preservedProducerResults.size() +
315  consumer.getNumDpsInits());
316  fusedResultTypes.reserve(preservedProducerResults.size() +
317  consumer.getNumDpsInits());
318  fusedIndexMaps.reserve(producer->getNumOperands() +
319  consumer->getNumOperands());
320  // In the following, numbering matches that of `generateFusedTensorOpRegion`.
321  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
322  auto consumerInputs = consumer.getDpsInputOperands();
323  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
324  return operand == fusedOperand;
325  });
326  assert(it != consumerInputs.end() && "expected to find the consumer operand");
327  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
328  fusedInputOperands.push_back(opOperand->get());
329  fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
330  }
331  // 4. Splice in producer's input operands/maps.
332  AffineMap producerResultIndexMap =
333  producer.getIndexingMapMatchingResult(producerResult);
334  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
335  fusedInputOperands.push_back(opOperand->get());
336  // Compute indexing maps for the producer args in the fused operation.
338  opOperand, producerResultIndexMap,
339  consumer.getMatchingIndexingMap(fusedOperand));
340  fusedIndexMaps.push_back(map);
341  }
342  // 5. Remaining consumer's input operands/maps (drop past index
343  // `consumerIdx`).
344  for (OpOperand *opOperand :
345  llvm::make_range(std::next(it), consumerInputs.end())) {
346  fusedInputOperands.push_back(opOperand->get());
347  fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
348  }
349 
350  // 6. Collect all of the producer outputs.
351  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
352  if (!preservedProducerResults.count(opOperand.index()))
353  continue;
354 
355  fusedOutputOperands.push_back(opOperand.value().get());
357  &opOperand.value(), producerResultIndexMap,
358  consumer.getMatchingIndexingMap(fusedOperand));
359  fusedIndexMaps.push_back(map);
360  fusedResultTypes.push_back(opOperand.value().get().getType());
361  }
362 
363  // 7. All of consumer's output operands (skip operands: added by the builder).
364  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
365  fusedOutputOperands.push_back(opOperand.get());
366  fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
367  Type resultType = opOperand.get().getType();
368  if (!isa<MemRefType>(resultType))
369  fusedResultTypes.push_back(resultType);
370  }
371 
372  // Generate the fused op.
373  auto fusedOp = rewriter.create<GenericOp>(
374  consumer.getLoc(), fusedResultTypes, fusedInputOperands,
375  fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
376  consumer.getIteratorTypes(),
377  /*doc=*/nullptr,
378  /*library_call=*/nullptr);
379  if (!fusedOp.getShapesToLoopsMap()) {
380  // Fused op has invalid indexing maps. Typically this means something is off
381  // in the input, but going ahead here would result in verification errors.
382  // So cleanup and abort.
383  rewriter.eraseOp(fusedOp);
384  return rewriter.notifyMatchFailure(
385  fusedOp, "fused op failed loop bound computation check");
386  }
387 
388  // Construct an AffineMap from consumer loops to producer loops.
389  // consumer loop -> tensor index
390  AffineMap consumerResultIndexMap =
391  consumer.getMatchingIndexingMap(fusedOperand);
392  // tensor index -> producer loop
393  AffineMap invProducerResultIndexMap =
394  inversePermutation(producerResultIndexMap);
395  assert(invProducerResultIndexMap &&
396  "expected producer result indexig map to be invertible");
397  // consumer loop -> producer loop
398  AffineMap consumerToProducerLoopsMap =
399  invProducerResultIndexMap.compose(consumerResultIndexMap);
400 
402  rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
403  consumer.getNumLoops(), preservedProducerResults);
405  result.fusedOp = fusedOp;
406  int resultNum = 0;
407  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
408  if (preservedProducerResults.count(index))
409  result.replacements[producerResult] = fusedOp->getResult(resultNum++);
410  for (auto consumerResult : consumer->getResults())
411  result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
412  return result;
413 }
414 
415 namespace {
416 /// Patterns to fuse a generic op, with the producer of its operands.
417 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
418 public:
419  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
420  PatternBenefit benefit = 1)
421  : OpRewritePattern<GenericOp>(context, benefit),
422  controlFn(std::move(fun)) {}
423 
424  LogicalResult matchAndRewrite(GenericOp genericOp,
425  PatternRewriter &rewriter) const override {
426  // Find the first operand that is defined by another generic op on tensors.
427  for (OpOperand &opOperand : genericOp->getOpOperands()) {
428  if (!areElementwiseOpsFusable(&opOperand))
429  continue;
430  if (!controlFn(&opOperand))
431  continue;
432 
433  Operation *producer = opOperand.get().getDefiningOp();
434 
435  // Find the producer of the operand.
436  FailureOr<ElementwiseOpFusionResult> fusionResult =
437  fuseElementwiseOps(rewriter, &opOperand);
438  if (failed(fusionResult))
439  return rewriter.notifyMatchFailure(genericOp, "fusion failed");
440 
441  // Perform the fusion.
442  for (auto [origVal, replacement] : fusionResult->replacements) {
443  rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
444  // Only replace consumer uses.
445  return use.get().getDefiningOp() != producer;
446  });
447  }
448  rewriter.eraseOp(genericOp);
449  return success();
450  }
451  return failure();
452  }
453 
454 private:
455  ControlFusionFn controlFn;
456 };
457 } // namespace
458 
459 //===---------------------------------------------------------------------===//
460 // Methods and patterns that fuse reshape ops with elementwise operations by
461 // expanding the dimensionality of the elementwise operations.
462 //===---------------------------------------------------------------------===//
463 
464 /// Conditions for folding a structured linalg operation with a reshape op by
465 /// expanding the iteration space dimensionality for tensor operations. These
466 /// are preconditions assumed by `foldReshapeByDimExpansion` which implements
467 /// the following fusion pattern.
468 ///
469 /// Consider
470 ///
471 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
472 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
473 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
474 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
475 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
476 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
477 ///
478 /// The reshape can be folded into the `linalgOp` if its loop dimensionality
479 /// is increased to match the result (operand) of the tensor.expand_shape.
480 /// The indexing_map of the fused tensor in the `linalgOp` and the
481 /// reassociation map helps compute the indexing maps of the modified op.
482 /// For the above example, based on the reassociation map it
483 /// can be concluded that
484 ///
485 /// - The loop used to access the first dimension of the fused tensor is split
486 /// into two.
487 /// - The loop used to access the second dimension of the fused tensor is kept
488 /// as is.
489 /// - The loop used to access the third dimension of the fused tensor is split
490 /// into three.
491 ///
492 /// i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
493 /// op, then
494 ///
495 /// d0 -> e0, e1
496 /// d1 -> e2, e3, e4
497 /// d2 -> e5
498 ///
499 /// substituting this, the structured op can be rewritten as
500 ///
501 /// %d = linalg.generic ins(%0, %1 : )
502 /// indexing_maps =
503 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
504 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
505 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
506 ///
507 /// Since operands to the linalg generic are now 5D, reshapes can be introduced
508 /// to make it consistent
509 ///
510 /// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
511 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
512 /// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
513 /// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
514 ///
515 /// The added reshapes are again expanding patterns, so they will get fused
516 /// with its producers if possible.
517 static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
518  OpOperand *fusableOpOperand) {
519  // Is fusable only if:
520  // - All the indexing maps for operands and results are projected
521  // permutations.
522  // - The fused tensor is not a scalar.
523  // - All the loops for the reshaped operand are parallel loops.
524  SmallVector<utils::IteratorType> iteratorTypes =
525  linalgOp.getIteratorTypesArray();
526  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
527  return linalgOp.hasPureTensorSemantics() &&
528  llvm::all_of(linalgOp.getIndexingMaps().getValue(),
529  [](Attribute attr) {
530  return cast<AffineMapAttr>(attr)
531  .getValue()
532  .isProjectedPermutation();
533  }) &&
534  operandMap.getNumResults() > 0 &&
535  llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
536  return isParallelIterator(
537  iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
538  });
539 }
540 
541 namespace {
542 /// Information needed to expand a generic operation to fold the reshape with
543 /// it.
544 class ExpansionInfo {
545 public:
546  // Computes the mapping from original dimensions of the op to the dimensions
547  // of the expanded op given the `indexingMap` of the fused operand/result of
548  // the generic op, the `reassocationMaps` of the reshape op and the shape of
549  // the expanded op.
550  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
551  ArrayRef<AffineMap> reassociationMaps,
552  ArrayRef<int64_t> expandedShape,
553  ArrayRef<int64_t> collapsedShape,
554  PatternRewriter &rewriter);
555  unsigned getOrigOpNumDims() const { return reassociation.size(); }
556  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
557  ReassociationIndicesRef getExpandedDims(unsigned i) const {
558  return reassociation[i];
559  }
560  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
561  return expandedShapeMap[i];
562  }
563  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
564 
565 private:
566  /// Reassociation from the dimensions in the original operation to the
567  /// dimension of the expanded operation.
568  SmallVector<ReassociationIndices> reassociation;
569  /// Mapping from extent of loops in the original operation, to the extent of
570  /// loops in the expanded operation.
571  SmallVector<SmallVector<int64_t>> expandedShapeMap;
572  /// Extent of the loop in the original operation.
573  SmallVector<int64_t> originalLoopExtent;
574  unsigned expandedOpNumDims;
575 };
576 } // namespace
577 
578 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
579  OpOperand *fusableOpOperand,
580  ArrayRef<AffineMap> reassociationMaps,
581  ArrayRef<int64_t> expandedShape,
582  ArrayRef<int64_t> collapsedShape,
583  PatternRewriter &rewriter) {
584  if (reassociationMaps.empty())
585  return failure();
586  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
587 
588  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
589  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
590 
591  reassociation.clear();
592  expandedShapeMap.clear();
593  // Compute the number of dimension in the expanded op that correspond to each
594  // dimension of the original op.
595  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
596  expandedShapeMap.resize(fusedIndexMap.getNumDims());
597  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
598  unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
599  AffineMap foldedDims = reassociationMaps[resultExpr.index()];
600  numExpandedDims[pos] = foldedDims.getNumResults();
601  ArrayRef<int64_t> shape =
602  expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
603  expandedShapeMap[pos].assign(shape.begin(), shape.end());
604  }
605  // The remaining dimensions remain the same.
606  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
607  if (expandedShapeMap[i].empty())
608  expandedShapeMap[i] = {originalLoopExtent[i]};
609 
610  // Compute reassociation map from the original op to the expanded op.
611  unsigned sum = 0;
612  reassociation.reserve(fusedIndexMap.getNumDims());
613  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
614  auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
615  reassociation.emplace_back(seq.begin(), seq.end());
616  sum += numFoldedDim.value();
617  }
618  expandedOpNumDims = sum;
619  return success();
620 }
621 
622 /// Expanding the body of a linalg operation requires adaptations of the
623 /// accessed loop indices. Specifically, access of indices in the original
624 /// operation need to be replaced with linearizations of indices in the expanded
625 /// op. That requires the shape of the expanded dimensions to be static (at
626 /// least all but the most significant). For now check that these are all
627 /// statically sized. Note that this could be extended to handle dynamic case,
628 /// but the implementation below uses `affine.apply` which seems to have issues
629 /// when the shapes are not static.
630 static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
631  const ExpansionInfo &expansionInfo,
632  PatternRewriter &rewriter) {
633  if (!linalgOp.hasIndexSemantics())
634  return success();
635  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
636  ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
637  if (expandedShape.size() == 1)
638  continue;
639  for (int64_t shape : expandedShape.drop_front()) {
640  if (ShapedType::isDynamic(shape)) {
641  return rewriter.notifyMatchFailure(
642  linalgOp, "cannot expand due to index semantics and dynamic dims");
643  }
644  }
645  }
646  return success();
647 }
648 
649 /// Return the indexing map to use in the expanded op for the given
650 /// `indexingMap` of the original operation.
651 static AffineMap
653  const ExpansionInfo &expansionInfo) {
654  SmallVector<AffineExpr> newExprs;
655  for (AffineExpr expr : indexingMap.getResults()) {
656  unsigned pos = cast<AffineDimExpr>(expr).getPosition();
657  SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
658  llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
659  return builder.getAffineDimExpr(static_cast<unsigned>(v));
660  }));
661  newExprs.append(expandedExprs.begin(), expandedExprs.end());
662  }
663  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
664  indexingMap.getNumSymbols(), newExprs,
665  builder.getContext());
666 }
667 
668 /// Return the type of the operand/result to use in the expanded op given the
669 /// type in the original op.
670 static RankedTensorType getExpandedType(RankedTensorType originalType,
671  AffineMap indexingMap,
672  const ExpansionInfo &expansionInfo) {
673  SmallVector<int64_t> expandedShape;
674  for (AffineExpr expr : indexingMap.getResults()) {
675  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
676  auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
677  expandedShape.append(dimExpansion.begin(), dimExpansion.end());
678  }
679  return RankedTensorType::get(expandedShape, originalType.getElementType());
680 }
681 
682 /// Returns the reassociation maps to use in the `tensor.expand_shape`
683 /// operation to convert the operands of the original operation to operands of
684 /// the expanded operation. The same method is used to compute the
685 /// `tensor.collapse_shape` used to collapse the result of the expanded
686 /// op to get the value that can replace all uses of the results of the original
687 /// op.
690  const ExpansionInfo &expansionInfo) {
691  SmallVector<ReassociationIndices> reassociation;
692  unsigned numReshapeDims = 0;
693  for (AffineExpr expr : indexingMap.getResults()) {
694  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
695  auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
696  SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
697  llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
698  reassociation.emplace_back(std::move(indices));
699  numReshapeDims += numExpandedDims;
700  }
701  return reassociation;
702 }
703 
704 /// Update the body of an expanded linalg operation having index semantics. The
705 /// indices of the original operation need to be recovered by linearizing the
706 /// indices of the corresponding dimensions of the expanded operation. For now it
707 /// is assumed that the shapes of the expanded operation needed for
708 /// linearization are static.
710  Location loc, Region &fusedRegion,
711  const ExpansionInfo &expansionInfo) {
712  // Replace the original indices by the linearization of the expanded indices.
713  for (IndexOp indexOp :
714  llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
715  ArrayRef<int64_t> expandedDims =
716  expansionInfo.getExpandedDims(indexOp.getDim());
717  assert(!expandedDims.empty() && "expected valid expansion info");
718 
719  // Skip index operations that are not affected by the expansion.
720  if (expandedDims.size() == 1 &&
721  expandedDims.front() == (int64_t)indexOp.getDim())
722  continue;
723 
724  // Linearize the expanded indices of the original index dimension.
725  OpBuilder::InsertionGuard guard(rewriter);
726  rewriter.setInsertionPointAfter(indexOp);
727  ArrayRef<int64_t> expandedDimsShape =
728  expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
729  SmallVector<Value> expandedIndices;
730  expandedIndices.reserve(expandedDims.size() - 1);
731  llvm::transform(
732  expandedDims.drop_front(), std::back_inserter(expandedIndices),
733  [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
734  Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
735  for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
736  assert(!ShapedType::isDynamic(std::get<0>(it)));
737  AffineExpr idx, acc;
738  bindDims(rewriter.getContext(), idx, acc);
739  newIndex = rewriter.create<affine::AffineApplyOp>(
740  indexOp.getLoc(), idx + acc * std::get<0>(it),
741  ValueRange{std::get<1>(it), newIndex});
742  }
743  rewriter.replaceOp(indexOp, newIndex);
744  }
745 }
746 
747 /// Checks if a single dynamic dimension expanded into multiple dynamic
748 /// dimensions.
749 static LogicalResult
750 validateDynamicDimExpansion(LinalgOp linalgOp,
751  const ExpansionInfo &expansionInfo,
752  PatternRewriter &rewriter) {
753  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
754  ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
755  if (expandedShape.size() == 1)
756  continue;
757  bool foundDynamic = false;
758  for (int64_t shape : expandedShape) {
759  if (!ShapedType::isDynamic(shape))
760  continue;
761  if (foundDynamic) {
762  return rewriter.notifyMatchFailure(
763  linalgOp, "cannot infer expanded shape with multiple dynamic "
764  "dims in the same reassociation group");
765  }
766  foundDynamic = true;
767  }
768  }
769  return success();
770 }
771 
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
775 static std::optional<SmallVector<Value>>
776 fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
777  OpOperand *fusableOpOperand,
778  PatternRewriter &rewriter) {
779  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
780  "preconditions for fuse operation failed");
781 
782  Location loc = linalgOp.getLoc();
783  // Check if reshape is expanding or collapsing.
784  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
785  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
786  bool isExpanding = (expandingReshapeOp != nullptr);
787  RankedTensorType expandedType = isExpanding
788  ? expandingReshapeOp.getResultType()
789  : collapsingReshapeOp.getSrcType();
790  RankedTensorType collapsedType = isExpanding
791  ? expandingReshapeOp.getSrcType()
792  : collapsingReshapeOp.getResultType();
793 
794  ExpansionInfo expansionInfo;
795  if (failed(expansionInfo.compute(
796  linalgOp, fusableOpOperand,
797  isExpanding ? expandingReshapeOp.getReassociationMaps()
798  : collapsingReshapeOp.getReassociationMaps(),
799  expandedType.getShape(), collapsedType.getShape(), rewriter)))
800  return std::nullopt;
801 
802  // TODO: With the support of multiple dynamic dims expansion in
803  // tensor.expand_shape op, this case can be handled.
804  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
805  return std::nullopt;
806 
807  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
808  return std::nullopt;
809 
810  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
811  llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
812  return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
813  }));
814 
815  // Set insertion point to the generic op.
816  OpBuilder::InsertionGuard g(rewriter);
817  rewriter.setInsertionPoint(linalgOp);
818 
819  SmallVector<Value> expandedOpOperands;
820  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
821  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
822  if (opOperand == fusableOpOperand) {
823  expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
824  : collapsingReshapeOp.getSrc());
825  continue;
826  }
827  if (auto opOperandType =
828  dyn_cast<RankedTensorType>(opOperand->get().getType())) {
829  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
830  RankedTensorType expandedOperandType =
831  getExpandedType(opOperandType, indexingMap, expansionInfo);
832  if (expandedOperandType != opOperand->get().getType()) {
833  // Reshape the operand to get the right type.
834  SmallVector<ReassociationIndices> reassociation =
835  getReassociationForExpansion(indexingMap, expansionInfo);
837  [&](const Twine &msg) {
838  return rewriter.notifyMatchFailure(linalgOp, msg);
839  },
840  opOperandType.getShape(), expandedOperandType.getShape(),
841  reassociation,
842  /*isExpandingReshape=*/true)))
843  return std::nullopt;
844  expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
845  loc, expandedOperandType, opOperand->get(), reassociation));
846  continue;
847  }
848  }
849  expandedOpOperands.push_back(opOperand->get());
850  }
851 
852  SmallVector<Value> outputs;
853  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
854  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
855  auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
856  RankedTensorType expandedOutputType =
857  getExpandedType(opOperandType, indexingMap, expansionInfo);
858  if (expandedOutputType != opOperand.get().getType()) {
859  SmallVector<ReassociationIndices> reassociation =
860  getReassociationForExpansion(indexingMap, expansionInfo);
862  [&](const Twine &msg) {
863  return rewriter.notifyMatchFailure(linalgOp, msg);
864  },
865  opOperandType.getShape(), expandedOutputType.getShape(),
866  reassociation,
867  /*isExpandingReshape=*/true)))
868  return std::nullopt;
869  outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
870  loc, expandedOutputType, opOperand.get(), reassociation));
871  } else {
872  outputs.push_back(opOperand.get());
873  }
874  }
875 
876  // The iterator types of the expanded op are all parallel.
877  SmallVector<utils::IteratorType> iteratorTypes(
878  expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
879  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
880  for (auto j : expansionInfo.getExpandedDims(i))
881  iteratorTypes[j] = type;
882 
883  TypeRange resultTypes = ValueRange(outputs).getTypes();
884  auto fusedOp =
885  rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
886  /*inputs=*/expandedOpOperands, outputs,
887  expandedOpIndexingMaps, iteratorTypes);
888  Region &fusedRegion = fusedOp->getRegion(0);
889  Region &originalRegion = linalgOp->getRegion(0);
890  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
891 
892  // Update the index accesses after the expansion.
893  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
894 
895  // Reshape the result values to their original shape if this is a collapsing
896  // reshape folded into its consumer.
897  SmallVector<Value> resultVals;
898  for (OpResult opResult : linalgOp->getOpResults()) {
899  int64_t resultNumber = opResult.getResultNumber();
900  if (resultTypes[resultNumber] != opResult.getType()) {
901  SmallVector<ReassociationIndices> reassociation =
903  linalgOp.getMatchingIndexingMap(
904  linalgOp.getDpsInitOperand(resultNumber)),
905  expansionInfo);
906  resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
907  linalgOp.getLoc(), opResult.getType(),
908  fusedOp->getResult(resultNumber), reassociation));
909  } else {
910  resultVals.push_back(fusedOp->getResult(resultNumber));
911  }
912  }
913  // Assuming a single result.
914  return resultVals;
915 }
916 
917 namespace {
918 
919 /// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
920 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
921 /// in the consumer is expanded.
922 class FoldWithProducerReshapeOpByExpansion
923  : public OpInterfaceRewritePattern<LinalgOp> {
924 public:
925  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
926  ControlFusionFn foldReshapes,
927  PatternBenefit benefit = 1)
928  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
929  controlFoldingReshapes(std::move(foldReshapes)) {}
930 
931  LogicalResult matchAndRewrite(LinalgOp linalgOp,
932  PatternRewriter &rewriter) const override {
933  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
934  tensor::CollapseShapeOp reshapeOp =
935  opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
936  if (!reshapeOp)
937  continue;
938  // Fold only if
939  // - The tensor reshape op is folding.
940  // - All constraints of fusing with reshape by expansion are met.
941  if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
942  (!controlFoldingReshapes(opOperand)))
943  continue;
944 
945  std::optional<SmallVector<Value>> replacementValues =
946  fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
947  if (!replacementValues)
948  return failure();
949  rewriter.replaceOp(linalgOp, *replacementValues);
950  return success();
951  }
952  return failure();
953  }
954 
955 private:
956  ControlFusionFn controlFoldingReshapes;
957 };
958 
959 class FoldPadWithProducerReshapeOpByExpansion
960  : public OpRewritePattern<tensor::PadOp> {
961 public:
962  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
963  ControlFusionFn foldReshapes,
964  PatternBenefit benefit = 1)
965  : OpRewritePattern<tensor::PadOp>(context, benefit),
966  controlFoldingReshapes(std::move(foldReshapes)) {}
967 
968  LogicalResult matchAndRewrite(tensor::PadOp padOp,
969  PatternRewriter &rewriter) const override {
970  tensor::CollapseShapeOp reshapeOp =
971  padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
972  if (!reshapeOp)
973  return failure();
974  if (!reshapeOp->hasOneUse())
975  return failure();
976 
977  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
978  return rewriter.notifyMatchFailure(padOp,
979  "fusion blocked by control function");
980  }
981 
982  ArrayRef<int64_t> low = padOp.getStaticLow();
983  ArrayRef<int64_t> high = padOp.getStaticHigh();
984  SmallVector<ReassociationIndices> reassociations =
985  reshapeOp.getReassociationIndices();
986 
987  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
988  if (reInd.size() != 1 && (l != 0 || h != 0))
989  return failure();
990  }
991 
992  SmallVector<OpFoldResult> newLow, newHigh;
993  RankedTensorType expandedType = reshapeOp.getSrcType();
994  RankedTensorType paddedType = padOp.getResultType();
995  SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
996  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
997  if (reInd.size() == 1) {
998  expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
999  }
1000  for (size_t i = 0; i < reInd.size(); ++i) {
1001  newLow.push_back(padOp.getMixedLowPad()[idx]);
1002  newHigh.push_back(padOp.getMixedHighPad()[idx]);
1003  }
1004  }
1005 
1006  Location loc = padOp->getLoc();
1007  RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1008  auto newPadOp = rewriter.create<tensor::PadOp>(
1009  loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1010  padOp.getConstantPaddingValue(), padOp.getNofold());
1011 
1012  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1013  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1014 
1015  return success();
1016  }
1017 
1018 private:
1019  ControlFusionFn controlFoldingReshapes;
1020 };
1021 
1022 /// Pattern to fold a tensor.expand_shape op with its producer generic op
1023 /// by expanding the dimensionality of the loop in the producer op.
1024 struct FoldReshapeWithGenericOpByExpansion
1025  : public OpRewritePattern<tensor::ExpandShapeOp> {
1026 
1027  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
1028  ControlFusionFn foldReshapes,
1029  PatternBenefit benefit = 1)
1030  : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1031  controlFoldingReshapes(std::move(foldReshapes)) {}
1032 
1033  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1034  PatternRewriter &rewriter) const override {
1035  // Fold only if all constraints of fusing with reshape by expansion are met.
1036  auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1037  if (!producerResult) {
1038  return rewriter.notifyMatchFailure(reshapeOp,
1039  "source not produced by an operation");
1040  }
1041 
1042  auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1043  if (!producer) {
1044  return rewriter.notifyMatchFailure(reshapeOp,
1045  "producer not a generic op");
1046  }
1047 
1049  producer,
1050  producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1051  return rewriter.notifyMatchFailure(
1052  reshapeOp, "failed preconditions of fusion with producer generic op");
1053  }
1054 
1055  if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1056  return rewriter.notifyMatchFailure(reshapeOp,
1057  "fusion blocked by control function");
1058  }
1059 
1060  std::optional<SmallVector<Value>> replacementValues =
1062  producer, reshapeOp,
1063  producer.getDpsInitOperand(producerResult.getResultNumber()),
1064  rewriter);
1065  if (!replacementValues) {
1066  return rewriter.notifyMatchFailure(reshapeOp,
1067  "fusion by expansion failed");
1068  }
1069 
1070  // Find the replacement for the reshape op. Since the replacements have the
1071  // same type as the returns of the original generic op, the consumer reshape
1072  // op can be replaced by the source of the collapse_shape op that defines
1073  // the replacement.
1074  Value reshapeReplacement =
1075  (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1076  .getResultNumber()];
1077  if (auto collapseOp =
1078  reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
1079  reshapeReplacement = collapseOp.getSrc();
1080  }
1081  rewriter.replaceOp(reshapeOp, reshapeReplacement);
1082  rewriter.replaceOp(producer, *replacementValues);
1083  return success();
1084  }
1085 
1086 private:
1087  ControlFusionFn controlFoldingReshapes;
1088 };
1089 } // namespace
1090 
1091 //===---------------------------------------------------------------------===//
1092 // Methods and patterns to fuse reshape with linalg.generic operations by
1093 // contraction of dimensions.
1094 //===---------------------------------------------------------------------===//
1095 
/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. `indexingMap` must
/// be a projected permutation, which ensures that all the elements of the
/// returned reassociation are distinct.
1100 static ReassociationIndices
1102  ReassociationIndicesRef rangeReassociation) {
1103  assert(indexingMap.isProjectedPermutation() &&
1104  "expected projected permutation");
1105 
1106  ReassociationIndices domainReassociation = llvm::to_vector<4>(
1107  llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1108  return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
1109  }));
1110  // The projected permutation semantics ensures that there is no repetition of
1111  // the domain indices.
1112  return domainReassociation;
1113 }
1114 
1115 /// For a given `dimSequence`, check if the sequence is conserved in the
1116 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
1117 /// Non-existence of the sequence returns true as well.
1119  ReassociationIndicesRef dimSequence) {
1120  assert(!dimSequence.empty() &&
1121  "expected non-empty list for dimension sequence");
1122  assert(indexingMap.isProjectedPermutation() &&
1123  "expected indexing map to be projected permutation");
1124 
1125  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1126  sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1127 
1128  unsigned dimSequenceStart = dimSequence[0];
1129  for (const auto &expr : enumerate(indexingMap.getResults())) {
1130  unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1131  // 1. Check if this start of the sequence.
1132  if (dimInMapStart == dimSequenceStart) {
1133  if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
1134  return false;
1135  // 1a. Check if sequence is preserved.
1136  for (const auto &dimInSequence : enumerate(dimSequence)) {
1137  unsigned dimInMap =
1138  cast<AffineDimExpr>(
1139  indexingMap.getResult(expr.index() + dimInSequence.index()))
1140  .getPosition();
1141  if (dimInMap != dimInSequence.value())
1142  return false;
1143  }
1144  // Found the sequence. Projected permutation
1145  // enforces that all AffineDimExprs in the result are unique, so no
1146  // further checks are needed.
1147  return true;
1148  }
1149  // 2. If position in the expr (which is of type AffineDimExpr) is part
1150  // of sequence, return false here. This implies the entire sequence does not
1151  // exist in the indexing map.
1152  if (sequenceElements.count(dimInMapStart))
1153  return false;
1154  }
1155  // 3. No element of sequence found. Return true.
1156  return true;
1157 }
1158 
1161  return llvm::all_of(maps, [&](AffineMap map) {
1162  return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
1163  return isDimSequencePreserved(map, dimSequence);
1164  });
1165  });
1166 }
1167 
1168 // Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
1170 // operation. If all dimensions created by expansion can be collapsed in the
1171 // iteration space then the reshape is defunct.
1172 //
1173 // Example:
1174 //
1175 // ```mlir
1176 // #map = affine_map<(d0, d1) -> (d0, d1)>
1177 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1178 // %2 = tensor.empty [..] : tensor<?x4xf32>
1179 // %3 = linalg.generic {
1180 // indexing_maps = [#map, #map],
1181 // iterator_types = ["parallel" ,"parallel"]}
1182 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1183 // ```
1184 //
1185 // can be fused by collapsing the dimensions of the iteration space.
1186 //
1187 // ```mlir
1188 // #map = affine_map<(d0) -> (d0)>
1189 // %2 = tensor.empty [..] : tensor<?xf32>
1190 // %3 = linalg.generic {
1191 // indexing_maps = [#map, #map],
1192 // iterator_types = ["parallel"]}
//   ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1194 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1195 // ```
1196 //
1197 // In the following example,
1198 //
1199 // ```mlir
1200 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
1201 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
1202 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1203 // %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
1205 // indexing_maps = [#map0, #map1],
1206 // iterator_types = ["parallel" ,"parallel"]}
1207 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1208 // ```
1209 //
1210 // the reshape cannot be fused with the generic op by collapsing the op
1211 // dimensions since the indexing maps will have to contain mods and divs
1212 // to preserve the accesses pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
1215 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1216  ArrayRef<ReassociationIndices> reassociation) {
1217  // Some basic checks for this fusion to be valid.
1218  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
1219  return {};
1220 
1221  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1222  return map.isProjectedPermutation();
1223  })) {
1224  return {};
1225  }
1226 
1227  // Compute all the loops with the reduction iterator types.
1228  SmallVector<unsigned> reductionDims;
1229  genericOp.getReductionDims(reductionDims);
1230 
1231  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1232  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1233  auto iteratorTypes = genericOp.getIteratorTypesArray();
1234  SmallVector<ReassociationIndices> iterationSpaceReassociation;
1235  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1236  assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1237 
1238  // Ignore dims that are not folded.
1239  if (foldedRangeDims.size() == 1)
1240  continue;
1241 
1242  ReassociationIndices foldedIterationSpaceDims =
1243  getDomainReassociation(indexingMap, foldedRangeDims);
1244 
1245  // Check that the folded iteration dims do not contain already processed
1246  // dims.
1247  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1248  return processedIterationDims.count(dim);
1249  }))
1250  continue;
1251 
1252  // Check that all folded iterator types are all parallel or all reductions.
1253  utils::IteratorType startIteratorType =
1254  iteratorTypes[foldedIterationSpaceDims[0]];
1255  if (!isParallelIterator(startIteratorType) &&
1256  !isReductionIterator(startIteratorType))
1257  continue;
1258  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1259  return iteratorTypes[dim] != startIteratorType;
1260  }))
1261  continue;
1262 
1263  // If the folded dimensions correspond to a "reduction" iterator type,
1264  // the folded dimensions need to be "in-order". Strictly speaking this is
1265  // not necessary, for reductions that are associative and commutative, but
1266  // using a more strict definition of reduction for now.
1267  if (isReductionIterator(startIteratorType)) {
1268  bool isContiguous = false;
1269  for (const auto &startDim : llvm::enumerate(reductionDims)) {
1270  // Move window in `reductionDims` to start of the folded iteration dims.
1271  if (startDim.value() != foldedIterationSpaceDims[0])
1272  continue;
1273  // If sizes doesnt match, trivial not contiguous. This condition should
1274  // not be hit.
1275  if (startDim.index() + foldedIterationSpaceDims.size() >
1276  reductionDims.size())
1277  break;
1278  // Check that the contiguity is maintained.
1279  isContiguous = true;
1280  for (const auto &foldedDim :
1281  llvm::enumerate(foldedIterationSpaceDims)) {
1282  if (reductionDims[foldedDim.index() + startDim.index()] !=
1283  foldedDim.value()) {
1284  isContiguous = false;
1285  break;
1286  }
1287  }
1288  break;
1289  }
1290  if (!isContiguous)
1291  continue;
1292  }
1293 
1294  // Check that the sequence is preserved in all indexing maps.
1295  if (llvm::any_of(genericOp.getIndexingMapsArray(),
1296  [&](AffineMap indexingMap) {
1297  return !isDimSequencePreserved(indexingMap,
1298  foldedIterationSpaceDims);
1299  }))
1300  continue;
1301 
1302  processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1303  foldedIterationSpaceDims.end());
1304  iterationSpaceReassociation.emplace_back(
1305  std::move(foldedIterationSpaceDims));
1306  }
1307 
1308  return iterationSpaceReassociation;
1309 }
1310 
1311 /// Helper class to carry state while collapsing the `linalg.generic` op.
1312 namespace {
1313 class CollapsingInfo {
1314 public:
1315  LogicalResult initialize(unsigned origNumLoops,
1316  ArrayRef<ReassociationIndices> foldedIterationDims) {
1317  llvm::SmallDenseSet<int64_t, 4> processedDims;
1318  // Find all the dims that are folded.
1319  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1320  if (foldedIterationDim.empty())
1321  continue;
1322  // If the folded dims contain dims already folded, that's illegal
1323  // specification. Repetition within a list is also illegal.
1324  for (auto dim : foldedIterationDim) {
1325  if (dim >= origNumLoops)
1326  return failure();
1327  if (processedDims.count(dim))
1328  return failure();
1329  processedDims.insert(dim);
1330  }
1331  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1332  foldedIterationDim.end());
1333  }
1334  if (processedDims.size() > origNumLoops)
1335  return failure();
1336 
1337  // Add all the preserved dims of the original op as single
1338  // elements to `collapsedOpToOrigOpIterationDim`.
1339  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1340  if (processedDims.count(dim))
1341  continue;
1342  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1343  }
1344 
1345  llvm::sort(collapsedOpToOrigOpIterationDim,
1347  return lhs[0] < rhs[0];
1348  });
1349  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1350  for (const auto &foldedDims :
1351  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1352  for (const auto &dim : enumerate(foldedDims.value()))
1353  origOpToCollapsedOpIterationDim[dim.value()] =
1354  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1355  }
1356  return success();
1357  }
1358 
1359  /// Return mapping from collapsed loop domain to original loop domain.
1360  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1361  return collapsedOpToOrigOpIterationDim;
1362  }
1363 
1364  /// Return mapping from original loop domain to collapsed loop domain. The
1365  /// mapping is a pair. First value is the dimension in the collapsed loop that
1366  /// the original loop is mapped to. Second is the relative position in folded
1367  /// list of this domain. For example if the original loop domain is 3D, and
1368  /// the collapsed loop domain is folding all of it, i.e.
1369  ///
1370  /// ```
1371  /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
1372  /// ```
1373  ///
1374  /// then
1375  ///
1376  /// ```
1377  /// origOpToCollapsedOpMapping[0] = {0, 0};
1378  /// origOpToCollapsedOpMapping[1] = {0, 1};
1379  /// origOpToCollapsedOpMapping[2] = {0, 2};
1380  /// origOpToCollapsedOpMapping[3] = {1, 0};
1381  /// origOpToCollapsedOpMapping[4] = {1, 1};
1382  /// ```
1383  ///
1384  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1385  return origOpToCollapsedOpIterationDim;
1386  }
1387 
1388  /// Return the collapsed op iteration domain rank.
1389  unsigned getCollapsedOpIterationRank() const {
1390  return collapsedOpToOrigOpIterationDim.size();
1391  }
1392 
1393 private:
1394  /// Map from the iteration domain index in collapsed op to the iteration
1395  /// domain indices in the original op.
1396  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1397 
1398  /// Map from iteration domain index in the original op to the iteration domain
1399  /// index in the collapsed op.
1400  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1401 };
1402 } // namespace
1403 
1404 /// Get the iterator types for the collapsed operation given the original
1405 /// iterator types and collapsed dimensions.
1408  const CollapsingInfo &collapsingInfo) {
1409  SmallVector<utils::IteratorType> collapsedIteratorTypes;
1410  for (ReassociationIndicesRef foldedIterDims :
1411  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1412  assert(!foldedIterDims.empty() &&
1413  "reassociation indices expected to have non-empty sets");
1414  // Just pick the iterator type of the first folded dim. Pre-condition checks
1415  // expected to have checked that iterator types of all folded dimensions are
1416  // the same.
1417  collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1418  }
1419  return collapsedIteratorTypes;
1420 }
1421 
1422 /// Compute the indexing map in the collapsed op that corresponds to the given
1423 /// `indexingMap` of the original operation.
1424 static AffineMap
1426  const CollapsingInfo &collapsingInfo) {
1427  MLIRContext *context = indexingMap.getContext();
1428  assert(indexingMap.isProjectedPermutation() &&
1429  "expected indexing map to be projected permutation");
1430  SmallVector<AffineExpr> resultExprs;
1431  auto origOpToCollapsedOpMapping =
1432  collapsingInfo.getOrigOpToCollapsedOpMapping();
1433  for (auto expr : indexingMap.getResults()) {
1434  unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1435  // If the dim is not the first of the collapsed dim, do nothing.
1436  if (origOpToCollapsedOpMapping[dim].second != 0)
1437  continue;
1438  // The next n-dims are guaranteed to be collapsed. So just use the
1439  // iteration dimension of the collapsed op.
1440  resultExprs.push_back(
1441  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1442  }
1443  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1444  resultExprs, context);
1445 }
1446 
1447 /// Return the `reassociation` indices to use to collapse the operand when the
1448 /// iteration space of a generic op is collapsed.
1451  const CollapsingInfo &collapsingInfo) {
1452  unsigned counter = 0;
1453  SmallVector<ReassociationIndices> operandReassociation;
1454  auto origOpToCollapsedOpMapping =
1455  collapsingInfo.getOrigOpToCollapsedOpMapping();
1456  auto collapsedOpToOrigOpMapping =
1457  collapsingInfo.getCollapsedOpToOrigOpMapping();
1458  while (counter < indexingMap.getNumResults()) {
1459  unsigned dim =
1460  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1461  // This is the start of a collapsed dimensions of the iteration that
1462  // is gauranteed to be preserved in the indexing map. The number of folded
1463  // dims is obtained from the collapsed op to original op mapping.
1464  unsigned numFoldedDims =
1465  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1466  .size();
1467  if (origOpToCollapsedOpMapping[dim].second == 0) {
1468  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1469  operandReassociation.emplace_back(range.begin(), range.end());
1470  }
1471  counter += numFoldedDims;
1472  }
1473  return operandReassociation;
1474 }
1475 
1476 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1477 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1478  OpOperand *opOperand,
1479  const CollapsingInfo &collapsingInfo,
1480  OpBuilder &builder) {
1481  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1482  SmallVector<ReassociationIndices> operandReassociation =
1483  getOperandReassociation(indexingMap, collapsingInfo);
1484 
1485  // If the number of entries in the reassociation for the operand is same as
1486  // the number of results of the indexing map, then nothing to do for this
1487  // operand.
1488  Value operand = opOperand->get();
1489  if (operandReassociation.size() == indexingMap.getNumResults())
1490  return operand;
1491 
1492  // Insert a reshape to collapse the dimensions.
1493  if (isa<MemRefType>(operand.getType())) {
1494  return builder
1495  .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1496  .getResult();
1497  }
1498  return builder
1499  .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1500  .getResult();
1501 }
1502 
1503 /// Modify the `linalg.index` operations in the original generic op, to its
1504 /// value in the collapsed operation.
1506  const CollapsingInfo &collapsingInfo,
1507  ValueRange loopRange,
1508  RewriterBase &rewriter) {
1509  OpBuilder::InsertionGuard g(rewriter);
1510  rewriter.setInsertionPointToStart(block);
1511 
1512  // Collect all the original index ops.
1513  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1514 
1515  // For each folded dimension list resolve the original induction variable
1516  // values in terms of the folded dimension induction variable.
1517  // i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1518  // can be inverted to
1519  // i2 = i_{folded} % d2
1520  // i1 = (i_{folded} / d2) % d1
1521  // i0 = i_{folded} / (d1 * d2)
1522  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1523  for (auto foldedDims :
1524  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1525  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1526  Value newIndexVal =
1527  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1528  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1529  indexReplacementVals[dim] =
1530  rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1531  newIndexVal =
1532  rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1533  }
1534  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1535  }
1536 
1537  for (auto indexOp : indexOps) {
1538  auto dim = indexOp.getDim();
1539  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1540  }
1541 }
1542 
1544  const CollapsingInfo &collapsingInfo,
1545  RewriterBase &rewriter,
1546  SmallVectorImpl<Value> &inputOperands,
1547  SmallVectorImpl<Value> &outputOperands,
1548  SmallVectorImpl<Type> &resultTypes) {
1549  Location loc = op->getLoc();
1550  inputOperands =
1551  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1552  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1553  rewriter);
1554  });
1555 
1556  // Get the output operands and result types.
1557  resultTypes.reserve(op.getNumDpsInits());
1558  outputOperands.reserve(op.getNumDpsInits());
1559  for (OpOperand &output : op.getDpsInitsMutable()) {
1560  Value newOutput =
1561  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1562  outputOperands.push_back(newOutput);
1563  // If the op has "buffer semantics", then the init operands are ranked
1564  // memrefs and the op has no results.
1565  if (!op.hasPureBufferSemantics())
1566  resultTypes.push_back(newOutput.getType());
1567  }
1568 }
1569 
1570 /// Clone a `LinalgOp` to a collapsed version of same name
1571 template <typename OpTy>
1572 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1573  const CollapsingInfo &collapsingInfo) {
1574  return nullptr;
1575 }
1576 
1577 /// Collapse any `LinalgOp` that does not require any specialization such as
1578 /// indexing_maps, iterator_types, etc.
1579 template <>
1580 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1581  const CollapsingInfo &collapsingInfo) {
1582  SmallVector<Value> inputOperands, outputOperands;
1583  SmallVector<Type> resultTypes;
1584  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1585  outputOperands, resultTypes);
1586 
1587  return clone(
1588  rewriter, origOp, resultTypes,
1589  llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1590 }
1591 
1592 /// Collapse a `GenericOp`
1593 template <>
1595  GenericOp origOp,
1596  const CollapsingInfo &collapsingInfo) {
1597  SmallVector<Value> inputOperands, outputOperands;
1598  SmallVector<Type> resultTypes;
1599  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1600  outputOperands, resultTypes);
1601  SmallVector<AffineMap> indexingMaps(
1602  llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1603  return getCollapsedOpIndexingMap(map, collapsingInfo);
1604  }));
1605 
1607  origOp.getIteratorTypesArray(), collapsingInfo));
1608 
1609  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1610  origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1611  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1612  Block *origOpBlock = &origOp->getRegion(0).front();
1613  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1614  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1615  collapsedOpBlock->getArguments());
1616  return collapsedOp;
1617 }
1618 
1619 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1620  RewriterBase &rewriter) {
1621  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1622  return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1623  } else {
1624  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1625  }
1626 }
1627 
1628 /// Implementation of fusion with reshape operation by collapsing dimensions.
1629 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1630  LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1631  RewriterBase &rewriter) {
1632  // Bail on trivial no-op cases.
1633  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1634  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1635  return foldedDims.size() <= 1;
1636  }))
1637  return failure();
1638 
1639  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1640  if (hasPureBufferSemantics &&
1641  !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1642  MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1643  if (!memRefToCollapse)
1644  return true;
1645 
1646  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1647  memRefToCollapse, foldedIterationDims);
1648  }))
1649  return rewriter.notifyMatchFailure(op,
1650  "memref is not guaranteed collapsible");
1651 
1652  CollapsingInfo collapsingInfo;
1653  if (failed(
1654  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1655  return rewriter.notifyMatchFailure(
1656  op, "illegal to collapse specified dimensions");
1657  }
1658 
1659  // Bail on non-canonical ranges.
1660  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1661  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1662  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1663  return cast<IntegerAttr>(attr).getInt() == value;
1664  llvm::APInt actual;
1665  return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
1666  actual.getSExtValue() == value;
1667  };
1668  if (!llvm::all_of(loopRanges, [&](Range range) {
1669  return opFoldIsConstantValue(range.offset, 0) &&
1670  opFoldIsConstantValue(range.stride, 1);
1671  })) {
1672  return rewriter.notifyMatchFailure(
1673  op, "expected all loop ranges to have zero start and unit stride");
1674  }
1675 
1676  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1677 
1678  Location loc = op->getLoc();
1679  if (collapsedOp.hasIndexSemantics()) {
1680  // Collect the loop range of the generic op.
1681  OpBuilder::InsertionGuard g(rewriter);
1682  rewriter.setInsertionPoint(collapsedOp);
1683  SmallVector<Value> loopBound =
1684  llvm::map_to_vector(loopRanges, [&](Range range) {
1685  return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1686  });
1687  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1688  collapsingInfo, loopBound, rewriter);
1689  }
1690 
1691  // Insert expanding reshape for the result to get back the original result
1692  // type.
1693  SmallVector<Value> results;
1694  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1695  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1696  auto originalResultType =
1697  cast<ShapedType>(originalResult.value().getType());
1698  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1699  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1700  AffineMap indexingMap =
1701  op.getIndexingMapMatchingResult(originalResult.value());
1702  SmallVector<ReassociationIndices> reassociation =
1703  getOperandReassociation(indexingMap, collapsingInfo);
1704  Value result;
1705  if (isa<MemRefType>(collapsedOpResult.getType())) {
1706  MemRefType expandShapeResultType = MemRefType::get(
1707  originalResultType.getShape(), originalResultType.getElementType());
1708  result = rewriter.create<memref::ExpandShapeOp>(
1709  loc, expandShapeResultType, collapsedOpResult, reassociation);
1710  } else {
1711  result = rewriter.create<tensor::ExpandShapeOp>(
1712  loc, originalResultType, collapsedOpResult, reassociation);
1713  }
1714  results.push_back(result);
1715  } else {
1716  results.push_back(collapsedOpResult);
1717  }
1718  }
1719  return CollapseResult{results, collapsedOp};
1720 }
1721 
1722 namespace {
1723 
1724 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1725 /// contracting dimensions of the loop.
1726 class FoldWithProducerReshapeOpByCollapsing
1727  : public OpRewritePattern<GenericOp> {
1728 public:
1729  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1730  ControlFusionFn foldReshapes,
1731  PatternBenefit benefit = 1)
1732  : OpRewritePattern<GenericOp>(context, benefit),
1733  controlFoldingReshapes(std::move(foldReshapes)) {}
1734 
1735  LogicalResult matchAndRewrite(GenericOp genericOp,
1736  PatternRewriter &rewriter) const override {
1737  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1738  tensor::ExpandShapeOp reshapeOp =
1739  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1740  if (!reshapeOp)
1741  continue;
1742 
1743  SmallVector<ReassociationIndices> collapsableIterationDims =
1744  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1745  reshapeOp.getReassociationIndices());
1746  if (collapsableIterationDims.empty() ||
1747  !controlFoldingReshapes(&opOperand)) {
1748  continue;
1749  }
1750 
1751  std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1752  genericOp, collapsableIterationDims, rewriter);
1753  if (!collapseResult) {
1754  return rewriter.notifyMatchFailure(
1755  genericOp, "failed to do the fusion by collapsing transformation");
1756  }
1757 
1758  rewriter.replaceOp(genericOp, collapseResult->results);
1759  return success();
1760  }
1761  return failure();
1762  }
1763 
1764 private:
1765  ControlFusionFn controlFoldingReshapes;
1766 };
1767 
1768 class FoldPadWithProducerReshapeOpByCollapsing
1769  : public OpRewritePattern<tensor::PadOp> {
1770 public:
1771  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1772  ControlFusionFn foldReshapes,
1773  PatternBenefit benefit = 1)
1774  : OpRewritePattern<tensor::PadOp>(context, benefit),
1775  controlFoldingReshapes(std::move(foldReshapes)) {}
1776 
1777  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1778  PatternRewriter &rewriter) const override {
1779  tensor::ExpandShapeOp reshapeOp =
1780  padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1781  if (!reshapeOp)
1782  return failure();
1783  if (!reshapeOp->hasOneUse())
1784  return failure();
1785 
1786  if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1787  return rewriter.notifyMatchFailure(padOp,
1788  "fusion blocked by control function");
1789  }
1790 
1791  ArrayRef<int64_t> low = padOp.getStaticLow();
1792  ArrayRef<int64_t> high = padOp.getStaticHigh();
1793  SmallVector<ReassociationIndices> reassociations =
1794  reshapeOp.getReassociationIndices();
1795 
1796  for (auto reInd : reassociations) {
1797  if (reInd.size() == 1)
1798  continue;
1799  if (llvm::any_of(reInd, [&](int64_t ind) {
1800  return low[ind] != 0 || high[ind] != 0;
1801  })) {
1802  return failure();
1803  }
1804  }
1805 
1806  SmallVector<OpFoldResult> newLow, newHigh;
1807  RankedTensorType collapsedType = reshapeOp.getSrcType();
1808  RankedTensorType paddedType = padOp.getResultType();
1809  SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1810  SmallVector<OpFoldResult> expandedPaddedSizes(
1811  getMixedValues(reshapeOp.getStaticOutputShape(),
1812  reshapeOp.getOutputShape(), rewriter));
1813  AffineExpr d0, d1, d2;
1814  bindDims(rewriter.getContext(), d0, d1, d2);
1815  auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1816  Location loc = reshapeOp->getLoc();
1817  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1818  OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1819  OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1820  if (reInd.size() == 1) {
1821  collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1823  rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1824  expandedPaddedSizes[reInd[0]] = paddedSize;
1825  }
1826  newLow.push_back(l);
1827  newHigh.push_back(h);
1828  }
1829 
1830  RankedTensorType collapsedPaddedType =
1831  paddedType.clone(collapsedPaddedShape);
1832  auto newPadOp = rewriter.create<tensor::PadOp>(
1833  loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1834  padOp.getConstantPaddingValue(), padOp.getNofold());
1835 
1836  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1837  padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1838  expandedPaddedSizes);
1839 
1840  return success();
1841  }
1842 
1843 private:
1844  ControlFusionFn controlFoldingReshapes;
1845 };
1846 
1847 /// Pattern to collapse dimensions.
1848 template <typename LinalgType>
1849 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1850 public:
1851  CollapseLinalgDimensions(MLIRContext *context,
1852  GetCollapsableDimensionsFn collapseDimensions,
1853  PatternBenefit benefit = 1)
1854  : OpRewritePattern<LinalgType>(context, benefit),
1855  controlCollapseDimension(std::move(collapseDimensions)) {}
1856 
1857  LogicalResult matchAndRewrite(LinalgType op,
1858  PatternRewriter &rewriter) const override {
1859  SmallVector<ReassociationIndices> collapsableIterationDims =
1860  controlCollapseDimension(op);
1861  if (collapsableIterationDims.empty())
1862  return failure();
1863 
1864  // Check if the specified list of dimensions to collapse is a valid list.
1865  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
1866  collapsableIterationDims)) {
1867  return rewriter.notifyMatchFailure(
1868  op, "specified dimensions cannot be collapsed");
1869  }
1870 
1871  std::optional<CollapseResult> collapseResult =
1872  collapseOpIterationDims(op, collapsableIterationDims, rewriter);
1873  if (!collapseResult) {
1874  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
1875  }
1876  rewriter.replaceOp(op, collapseResult->results);
1877  return success();
1878  }
1879 
1880 private:
1881  GetCollapsableDimensionsFn controlCollapseDimension;
1882 };
1883 
1884 } // namespace
1885 
1886 //===---------------------------------------------------------------------===//
1887 // Methods and patterns that fuse constants with linalg.generic operations.
1888 //===---------------------------------------------------------------------===//
1889 
1890 namespace {
1891 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1892 /// handle cases where the constant is not single-valued.
1893 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1894 public:
1895  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1896  : OpRewritePattern<GenericOp>(context, benefit) {}
1897 
1898  LogicalResult matchAndRewrite(GenericOp genericOp,
1899  PatternRewriter &rewriter) const override {
1900  if (!genericOp.hasPureTensorSemantics())
1901  return failure();
1902  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1903  Operation *def = opOperand->get().getDefiningOp();
1904  TypedAttr constantAttr;
1905  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1906  {
1907  DenseElementsAttr splatAttr;
1908  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1909  splatAttr.isSplat() &&
1910  splatAttr.getType().getElementType().isIntOrFloat()) {
1911  constantAttr = splatAttr.getSplatValue<TypedAttr>();
1912  return true;
1913  }
1914  }
1915  {
1916  IntegerAttr intAttr;
1917  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1918  constantAttr = intAttr;
1919  return true;
1920  }
1921  }
1922  {
1923  FloatAttr floatAttr;
1924  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1925  constantAttr = floatAttr;
1926  return true;
1927  }
1928  }
1929  return false;
1930  };
1931 
1932  auto resultValue = dyn_cast<OpResult>(opOperand->get());
1933  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1934  continue;
1935 
1936  // The operands and the indexing_maps of the fused operation the same as
1937  // the operands and indexing_maps of the generic operations with the
1938  // values at the constant index dropped.
1939  SmallVector<AffineMap> fusedIndexMaps;
1940  SmallVector<Value> fusedOperands;
1941  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1942  fusedIndexMaps.reserve(genericOp->getNumOperands());
1943  fusedOperands.reserve(genericOp.getNumDpsInputs());
1944  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1945  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1946  if (inputOperand == opOperand)
1947  continue;
1948  Value inputValue = inputOperand->get();
1949  fusedIndexMaps.push_back(
1950  genericOp.getMatchingIndexingMap(inputOperand));
1951  fusedOperands.push_back(inputValue);
1952  fusedLocs.push_back(inputValue.getLoc());
1953  }
1954  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1955  fusedIndexMaps.push_back(
1956  genericOp.getMatchingIndexingMap(&outputOperand));
1957 
1958  // Check if the operation shapes to loops map is computable.
1959  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1960  return rewriter.notifyMatchFailure(
1961  genericOp, "fused op loop bound computation failed");
1962  }
1963 
1964  // Create a constant scalar value from the splat constant.
1965  Value scalarConstant =
1966  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
1967 
1968  SmallVector<Value> outputOperands = genericOp.getOutputs();
1969  auto fusedOp = rewriter.create<GenericOp>(
1970  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1971  /*inputs=*/fusedOperands,
1972  /*outputs=*/outputOperands,
1973  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1974  genericOp.getIteratorTypes(),
1975  /*doc=*/nullptr,
1976  /*library_call=*/nullptr);
1977 
1978  // Map the block argument corresponding to the replaced argument with the
1979  // scalar constant.
1980  Region &region = genericOp->getRegion(0);
1981  Block &entryBlock = *region.begin();
1982  IRMapping mapping;
1983  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1984  scalarConstant);
1985  Region &fusedRegion = fusedOp->getRegion(0);
1986  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1987  mapping);
1988  rewriter.replaceOp(genericOp, fusedOp->getResults());
1989  return success();
1990  }
1991  return failure();
1992  }
1993 };
1994 
1995 } // namespace
1996 
1997 //===---------------------------------------------------------------------===//
1998 // Miscellaneous patterns that help fusion.
1999 //===---------------------------------------------------------------------===//
2000 
2001 namespace {
2002 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2003 /// value of the `outs` operand is not used within the op. This is only
2004 /// implemented for `linalg.generic` operations for now, but should hold for all
2005 /// linalg structured ops.
///
/// Inits are skipped when they are not ranked tensors, when they are sparse,
/// or when they are already produced by `tensor.empty`. The match fails, and
/// the in-place modification is cancelled, when no init was replaced.
2006 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2008 
2009  LogicalResult matchAndRewrite(GenericOp op,
2010  PatternRewriter &rewriter) const override {
// Operands are replaced in place on `op`. The edit is bracketed by
// startOpModification and either cancelOpModification or
// finalizeOpModification.
2011  rewriter.startOpModification(op);
2012  bool modifiedOutput = false;
2013  Location loc = op.getLoc();
2014  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
// Only inits whose incoming value is never read by the payload qualify.
2015  if (!op.payloadUsesValueFromOperand(&opOperand)) {
2016  Value operandVal = opOperand.get();
2017  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2018  if (!operandType)
2019  continue;
2020 
2021  // If outs is sparse, leave it to the sparsifier.
2023  continue;
2024 
2025  // If outs is already an `empty` operation, nothing to do.
2026  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2027  if (definingOp)
2028  continue;
2029  modifiedOutput = true;
// Replace the init with a `tensor.empty` that has the init's sizes, as
// returned by `tensor::getMixedSizes`, and the same element type.
2030  SmallVector<OpFoldResult> mixedSizes =
2031  tensor::getMixedSizes(rewriter, loc, operandVal);
2032  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
2033  loc, mixedSizes, operandType.getElementType());
2034  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2035  }
2036  }
2037  if (!modifiedOutput) {
2038  rewriter.cancelOpModification(op);
2039  return failure();
2040  }
2041  rewriter.finalizeOpModification(op);
2042  return success();
2043  }
2044 };
2045 
2046 /// Fold linalg.fill into linalg.generic
2047 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2049 
2050  LogicalResult matchAndRewrite(GenericOp genericOp,
2051  PatternRewriter &rewriter) const override {
2052  if (!genericOp.hasPureTensorSemantics())
2053  return failure();
2054  bool fillFound = false;
2055  Block &payload = genericOp.getRegion().front();
2056  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2057  if (!genericOp.payloadUsesValueFromOperand(opOperand))
2058  continue;
2059  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2060  if (!fillOp)
2061  continue;
2062  fillFound = true;
2063  Value fillVal = fillOp.value();
2064  auto resultType =
2065  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2066  Value convertedVal =
2067  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2068  /*isUnsignedCast =*/false);
2069  rewriter.replaceAllUsesWith(
2070  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2071  }
2072  return success(fillFound);
2073  }
2074 };
2075 } // namespace
2076 
2078  RewritePatternSet &patterns,
2079  const ControlFusionFn &controlFoldingReshapes) {
// Register the three reshape-folding-by-expansion patterns. Each one receives
// `controlFoldingReshapes` to decide whether a given fusion may happen.
2080  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2081  controlFoldingReshapes);
2082  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2083  controlFoldingReshapes);
2084  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2085  controlFoldingReshapes);
2086 }
2087 
2089  RewritePatternSet &patterns,
2090  const ControlFusionFn &controlFoldingReshapes) {
// Register the expand_shape-by-collapsing folding patterns for
// `linalg.generic` and `tensor.pad` consumers. Both are gated by
// `controlFoldingReshapes`.
2091  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2092  controlFoldingReshapes);
2093  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2094  patterns.getContext(), controlFoldingReshapes);
2095 }
2096 
2098  RewritePatternSet &patterns,
2099  const ControlFusionFn &controlElementwiseOpsFusion) {
2100  auto *context = patterns.getContext();
// Only the elementwise fusion pattern consults the control function. The
// helper patterns registered next are added unconditionally.
2101  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2102  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2103  RemoveOutsDependency>(context);
2104  // Add the patterns that clean up dead operands and results.
2106 }
2107 
2109  RewritePatternSet &patterns,
2110  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
// Dimension collapsing is instantiated for `linalg.generic` and `linalg.copy`.
// `controlCollapseDimensions` selects which loop dims to fold for each op.
2111  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2112  CollapseLinalgDimensions<linalg::CopyOp>>(
2113  patterns.getContext(), controlCollapseDimensions);
2114 }
2115 
2116 //===---------------------------------------------------------------------===//
2117 // Passes
2118 //===---------------------------------------------------------------------===//
2119 
2120 namespace {
2121 
2122 /// Pass that fuses generic ops on tensors. Used only for testing.
2123 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2124 // patterns added here heavily depends on the cost function used. Having an
2125 // opinionated pass of this form is not recommended. Deprecate this pass in
2126 // favor of test passes that check the functionality of each of the patterns
2127 // added here individually.
2128 struct LinalgElementwiseOpFusionPass
2129  : public impl::LinalgElementwiseOpFusionPassBase<
2130  LinalgElementwiseOpFusionPass> {
2131  using impl::LinalgElementwiseOpFusionPassBase<
2132  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2133  void runOnOperation() override {
2134  Operation *op = getOperation();
2135  MLIRContext *context = op->getContext();
2136  RewritePatternSet patterns(context);
2137 
2138  // Add folding with reshape by expansion patterns.
2139  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2140  Operation *producer = fusedOperand->get().getDefiningOp();
2141  return producer && producer->hasOneUse();
2142  };
2143 
2144  // Add elementwise op fusion patterns.
2145  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2146  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2147 
2148  // General canonicalization patterns.
2149  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2150  GenericOp::getCanonicalizationPatterns(patterns, context);
2151  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2152  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2153  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2154  patterns);
2155 
2156  // Add constant folding patterns.
2157  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2158 
2159  // Use TopDownTraversal for compile time reasons
2160  GreedyRewriteConfig grc;
2161  grc.useTopDownTraversal = true;
2162  (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
2163  }
2164 };
2165 
2166 } // namespace
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Expanding the body of a linalg operation requires adaptations of the accessed loop indices.
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ValueRange loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static LogicalResult validateDynamicDimExpansion(LinalgOp linalgOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Checks if a single dynamic dimension expanded into multiple dynamic dimensions.
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
static RankedTensorType getExpandedType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the type of the operand/result to use in the expanded op given the type in the original op.
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:595
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:631
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:625
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator begin()
Definition: Block.h:141
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:191
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:29
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:582
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:624
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition: Utils.cpp:184
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
Definition: Transforms.h:1650
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
Definition: Transforms.h:1681
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:768
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:813
ArrayRef< int64_t > ReassociationIndicesRef
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:606
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
Definition: Transforms.h:497
llvm::DenseMap< Value, Value > replacements
Definition: Transforms.h:499
static llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer)
Returns a set of indices of the producer's results which would be preserved after the fusion.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.