//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors patterns and pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use for an operand of the `producer` in the
/// fused operation, given the indexing map of the result of the producer in
/// the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
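
// A hand-worked illustration of the composition above (maps invented for
// this example only):
//   producerResultIndexMap   : (d0, d1) -> (d1, d0)     (producer loops -> result index)
//   argMap                   : (d0, d1) -> (d0)         (producer loops -> arg index)
//   fusedConsumerArgIndexMap : (c0, c1, c2) -> (c2, c0) (consumer loops -> operand index)
// The inverse of producerResultIndexMap is (i0, i1) -> (i1, i0), so
//   t1 = argMap o inverse = (i0, i1) -> (i1)
// and the returned map is
//   t1 o fusedConsumerArgIndexMap = (c0, c1, c2) -> (c0),
// i.e. the producer argument is indexed directly by consumer loop c0 in the
// fused op.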

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
llvm::SmallDenseSet<int>
mlir::linalg::getPreservedProducerResults(GenericOp producer,
                                          GenericOp consumer) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !producer.canOpOperandsBeDropped(outputOperand) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());
    }
  }
  return preservedProducerResults;
}
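
// For example (a hedged sketch, not from a test): given
//   %p:2 = linalg.generic ... outs(%i0, %i1 : ...) -> (tensor<...>, tensor<...>)
//   %c   = linalg.generic ins(%p#0, %p#1 : ...) ...
//   "some.use"(%p#1) ...
// result #1 must be preserved since it has a user other than the consumer,
// while result #0 may be dropped if its init operand is unused by the payload
// and its operands can be dropped.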

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
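
// As a hedged illustration (names and shapes invented for this sketch), the
// following pair satisfies the conditions above, with
// #id = affine_map<(d0, d1) -> (d0, d1)>: all loops are parallel, the
// producer's result map is a permutation, and the fused operand %0 is an
// input of the consumer:
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%a : tensor<?x?xf32>) outs(%init0 : tensor<?x?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//       outs(%init1 : tensor<?x?xf32>) {...}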

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield, built from the remapped yield values of both
  // the producer and the consumer.
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
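
// A hedged before/after sketch of the rewrite this implements (payload ops
// invented for illustration):
//
//   %0 = linalg.generic ... ins(%a : ...) outs(%i0 : ...) {
//     ^bb0(%x: f32, %o0: f32):
//       %r = arith.addf %x, %x : f32
//       linalg.yield %r : f32
//   }
//   %1 = linalg.generic ... ins(%0, %b : ...) outs(%i1 : ...) {
//     ^bb0(%y: f32, %z: f32, %o1: f32):
//       %s = arith.mulf %y, %z : f32
//       linalg.yield %s : f32
//   }
//
// becomes a single generic op whose region chains both payloads:
//
//   %1 = linalg.generic ... ins(%a, %b : ...) outs(%i1 : ...) {
//     ^bb0(%x: f32, %z: f32, %o1: f32):
//       %r = arith.addf %x, %x : f32
//       %s = arith.mulf %r, %z : f32
//       linalg.yield %s : f32
//   }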

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Do not fuse a sparse-in/dense-out operation, as the
      // result is too often not sparsifiable anymore.
      if (sparse_tensor::hasAnySparseOperand(producer) &&
          !sparse_tensor::hasAnySparseOperand(genericOp))
        return failure();

      // Find the producer of the operand.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the fusion.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace
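
// Illustrative usage (a sketch, not part of this file): these patterns are
// normally registered via populateElementwiseOpsFusionPatterns, declared in
// mlir/Dialect/Linalg/Transforms/Transforms.h, and run under the greedy
// pattern rewrite driver:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));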

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `fuseWithReshapeByExpansion` which implements the
/// following fusion pattern.
///
/// Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the iteration space of the modified op is now 6D, reshapes can be
/// introduced to make the operands consistent with the new indexing maps
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasPureTensorSemantics() &&
         llvm::all_of(genericOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
             0 &&
         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
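
// A hedged worked example of `compute`: for a fused operand with indexing map
// (d0, d1) -> (d0, d1) and an expanding reshape with reassociation
// [[0, 1], [2]] to shape [?, 4, 8]:
//   - d0 expands to dims {e0, e1} with shapes [?, 4];
//   - d1 expands to dim {e2} with shape [8];
// so `reassociation` becomes [[0, 1], [2]] and `expandedOpNumDims` is 3.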

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation needs to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
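
// E.g. (hedged example), with d0 -> {e0, e1} and d1 -> {e2} from the
// expansion info, the original map (d0, d1) -> (d1, d0) becomes
// (e0, e1, e2) -> (e2, e0, e1).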

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
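
// E.g. (a hedged worked example): if the original dimension of
// `linalg.index 0` expands into {e0, e1, e2} with static trailing sizes
// [4, 8], the loop above rebuilds the original index as
//   (%idx_e0 * 4 + %idx_e1) * 8 + %idx_e2
// through a chain of affine.apply ops of the form idx + acc * size.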

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(genericOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumDpsInputs());
  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(genericOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(genericOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand.get(),
          reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getMatchingIndexingMap(
                  genericOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is collapsing dimensions.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
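
// E.g. (hedged example): for indexingMap (d0, d1, d2, d3) -> (d1, d3, d0) and
// rangeReassociation [0, 1], the returned domain reassociation is [1, 3].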

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of sequence found. Return true.
  return true;
}
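
// Hedged examples for the map (d0, d1, d2) -> (d1, d2, d0):
//   - sequence [1, 2] is preserved: d1 and d2 appear contiguously and in
//     order;
//   - sequence [0, 1] is not: d1 is encountered before the sequence start d0;
//   - a sequence over dims that the map does not reference at all is
//     trivially preserved.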

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that all folded iterator types are all parallel or all reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration
        // dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that's an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return mapping from original loop domain to collapsed loop domain. The
  /// mapping is a pair. The first value is the dimension in the collapsed loop
  /// that the original loop is mapped to. The second is the relative position
  /// in the folded list of this domain. For example, if the original loop
  /// domain is 5D and the collapsed loop domain folds all of it, i.e.
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from iteration domain index in the original op to the iteration
  /// domain index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<utils::IteratorType> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have verified that the iterator types of all
    // folded dimensions are the same.
    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    // If the dim is not the first of the collapsed dim, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The next n-dims are guaranteed to be collapsed. So just use the
    // iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}
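
// E.g. (hedged example): if loops (d0, d1) fold into collapsed dim c0 and d2
// stays as collapsed dim c1, the original map (d0, d1, d2) -> (d2, d0, d1)
// collapses to (c0, c1) -> (c1, c0).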
1352 
1353 /// Return the `reassociation` indices to use to collapse the operand when the
1354 /// iteration space of a generic op is collapsed.
1357  const CollapsingInfo &collapsingInfo) {
1358  unsigned counter = 0;
1359  SmallVector<ReassociationIndices> operandReassociation;
1360  auto origOpToCollapsedOpMapping =
1361  collapsingInfo.getOrigOpToCollapsedOpMapping();
1362  auto collapsedOpToOrigOpMapping =
1363  collapsingInfo.getCollapsedOpToOrigOpMapping();
1364  while (counter < indexingMap.getNumResults()) {
1365  unsigned dim =
1366  cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1367  // This is the start of a group of collapsed iteration dimensions that is
1368  // guaranteed to be preserved in the indexing map. The number of folded
1369  // dims is obtained from the collapsed op to original op mapping.
1370  unsigned numFoldedDims =
1371  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1372  .size();
1373  if (origOpToCollapsedOpMapping[dim].second == 0) {
1374  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1375  operandReassociation.emplace_back(range.begin(), range.end());
1376  }
1377  counter += numFoldedDims;
1378  }
1379  return operandReassociation;
1380 }
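
// A worked example (hypothetical map): for an operand indexing map
// (d0, d1, d2, d3) -> (d0, d1, d3) with collapsedOpToOrigOpMapping =
// [[0, 1], [2], [3]], the walk first sees d0 (start of the folded pair
// [0, 1], numFoldedDims = 2) and then d3 (a singleton group), yielding the
// operand reassociation [[0, 1], [2]].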
1381 
1382 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1383 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1384  OpOperand *opOperand,
1385  const CollapsingInfo &collapsingInfo,
1386  OpBuilder &builder) {
1387  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1388  SmallVector<ReassociationIndices> operandReassociation =
1389  getOperandReassociation(indexingMap, collapsingInfo);
1390 
1391  // If the number of entries in the reassociation for the operand is the same
1392  // as the number of results of the indexing map, then there is nothing to do
1393  // for this operand.
1394  Value operand = opOperand->get();
1395  if (operandReassociation.size() == indexingMap.getNumResults())
1396  return operand;
1397 
1398  // Insert a reshape to collapse the dimensions.
1399  if (isa<MemRefType>(operand.getType())) {
1400  return builder
1401  .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1402  .getResult();
1403  }
1404  return builder
1405  .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1406  .getResult();
1407 }
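
// A worked example (hypothetical types): an operand of type
// tensor<4x8x16xf32> with operand reassociation [[0, 1], [2]] (two entries
// for a three-result indexing map) gets a reshape, schematically
//   %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
//       : tensor<4x8x16xf32> into tensor<32x16xf32>
// while an operand whose reassociation has one entry per indexing-map result
// is returned unchanged.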
1408 
1409 /// Modify the `linalg.index` operations in the original generic op to reflect
1410 /// their values in the collapsed operation.
1411 void generateCollapsedIndexingRegion(Location loc, Block *block,
1412  const CollapsingInfo &collapsingInfo,
1413  ValueRange loopRange,
1414  RewriterBase &rewriter) {
1415  OpBuilder::InsertionGuard g(rewriter);
1416  rewriter.setInsertionPointToStart(block);
1417 
1418  // Collect all the original index ops.
1419  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1420 
1421  // For each folded dimension list resolve the original induction variable
1422  // values in terms of the folded dimension induction variable.
1423  // i_{folded} = (i0 * d1 + i1) * d2 + i2
1424  // can be inverted to
1425  // i2 = i_{folded} % d2
1426  // i1 = (i_{folded} / d2) % d1
1427  // i0 = i_{folded} / (d1 * d2)
1428  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1429  for (auto foldedDims :
1430  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1431  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1432  Value newIndexVal =
1433  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1434  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1435  indexReplacementVals[dim] =
1436  rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1437  newIndexVal =
1438  rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1439  }
1440  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1441  }
1442 
1443  for (auto indexOp : indexOps) {
1444  auto dim = indexOp.getDim();
1445  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1446  }
1447 }
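
// A worked numeric example (hypothetical sizes): folding dims {0, 1, 2} with
// d1 = 4 and d2 = 5, and %f the single linalg.index of the collapsed op:
//   i2 = %f % 5, i1 = (%f / 5) % 4, i0 = (%f / 5) / 4
// For %f = 37 this recovers (i0, i1, i2) = (1, 3, 2), consistent with
// (1 * 4 + 3) * 5 + 2 = 37.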
1448 
1449 template <typename LinalgType>
1450 Operation *createCollapsedOp(LinalgType op,
1451  const CollapsingInfo &collapsingInfo,
1452  RewriterBase &rewriter) {
1453  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1454  "unsupported linalg op type to create");
1455  Location loc = op->getLoc();
1456 
1457  // Get the input operands.
1458  SmallVector<Value> inputOperands =
1459  llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1460  return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1461  rewriter);
1462  });
1463 
1464  // Get the output operands and result types.
1465  SmallVector<Type> resultTypes;
1466  SmallVector<Value> outputOperands;
1467  resultTypes.reserve(op.getNumDpsInits());
1468  outputOperands.reserve(op.getNumDpsInits());
1469  for (OpOperand &output : op.getDpsInitsMutable()) {
1470  Value newOutput =
1471  getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1472  outputOperands.push_back(newOutput);
1473  // If the op has "buffer semantics", then the init operands are ranked
1474  // memrefs and the op has no results.
1475  if (!op.hasPureBufferSemantics())
1476  resultTypes.push_back(newOutput.getType());
1477  }
1478 
1479  if (isa<linalg::CopyOp>(op)) {
1480  return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
1481  outputOperands[0]);
1482  }
1483 
1484  // Get the iterator types for the operand.
1485  SmallVector<utils::IteratorType> iteratorTypes =
1486  getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
1487 
1488  // Get the indexing maps.
1489  auto indexingMaps =
1490  llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
1491  return getCollapsedOpIndexingMap(map, collapsingInfo);
1492  });
1493 
1494  Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
1495  loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1496  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1497  Block *origOpBlock = &op->getRegion(0).front();
1498  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1499  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1500  collapsedOpBlock->getArguments());
1501 
1502  return collapsedOp;
1503 }
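
// A schematic before/after (hypothetical IR): collapsing both parallel loops
// of a rank-2 elementwise generic
//   linalg.generic {iterator_types = ["parallel", "parallel"], ...}
//       ins(%in : tensor<4x8xf32>) outs(%out : tensor<4x8xf32>)
// produces a rank-1 generic over tensor<32xf32> operands; the original
// payload block is moved into the new op unchanged via mergeBlocks.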
1504 
1505 /// Implementation of fusion with reshape operation by collapsing dimensions.
1506 template <typename LinalgType>
1507 FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
1508  LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
1509  RewriterBase &rewriter) {
1510  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1511  "unsupported linalg op type to collapse");
1512 
1513  // Bail on trivial no-op cases.
1514  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1515  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1516  return foldedDims.size() <= 1;
1517  }))
1518  return failure();
1519 
1520  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1521  if (hasPureBufferSemantics &&
1522  !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1523  MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1524  if (!memRefToCollapse)
1525  return true;
1526 
1527  return memref::CollapseShapeOp::isGuaranteedCollapsible(
1528  memRefToCollapse, foldedIterationDims);
1529  }))
1530  return rewriter.notifyMatchFailure(op,
1531  "memref is not guaranteed collapsible");
1532 
1533  CollapsingInfo collapsingInfo;
1534  if (failed(
1535  collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1536  return rewriter.notifyMatchFailure(
1537  op, "illegal to collapse specified dimensions");
1538  }
1539 
1540  // Bail on non-canonical ranges.
1541  SmallVector<Range> loopRanges =
1542  cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
1543  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1544  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1545  return cast<IntegerAttr>(attr).getInt() == value;
1546  llvm::APInt actual;
1547  return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
1548  actual.getSExtValue() == value;
1549  };
1550  if (!llvm::all_of(loopRanges, [&](Range range) {
1551  return opFoldIsConstantValue(range.offset, 0) &&
1552  opFoldIsConstantValue(range.stride, 1);
1553  })) {
1554  return rewriter.notifyMatchFailure(
1555  op, "expected all loop ranges to have zero start and unit stride");
1556  }
1557 
1558  LinalgType collapsedOp = cast<LinalgType>(
1559  createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
1560 
1561  Location loc = op->getLoc();
1562  if (collapsedOp.hasIndexSemantics()) {
1563  // Collect the loop range of the generic op.
1564  OpBuilder::InsertionGuard g(rewriter);
1565  rewriter.setInsertionPoint(collapsedOp);
1566  SmallVector<Value> loopBound =
1567  llvm::map_to_vector(loopRanges, [&](Range range) {
1568  return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1569  });
1570  generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1571  collapsingInfo, loopBound, rewriter);
1572  }
1573 
1574  // Insert expanding reshape for the result to get back the original result
1575  // type.
1576  SmallVector<Value> results;
1577  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1578  Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1579  auto originalResultType =
1580  cast<ShapedType>(originalResult.value().getType());
1581  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1582  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1583  AffineMap indexingMap =
1584  op.getIndexingMapMatchingResult(originalResult.value());
1585  SmallVector<ReassociationIndices> reassociation =
1586  getOperandReassociation(indexingMap, collapsingInfo);
1587  if (isa<MemRefType>(collapsedOpResult.getType())) {
1588  Value result = rewriter.create<memref::ExpandShapeOp>(
1589  loc, originalResultType, collapsedOpResult, reassociation);
1590  results.push_back(result);
1591  } else {
1592  Value result = rewriter.create<tensor::ExpandShapeOp>(
1593  loc, originalResultType, collapsedOpResult, reassociation);
1594  results.push_back(result);
1595  }
1596  } else {
1597  results.push_back(collapsedOpResult);
1598  }
1599  }
1600  return results;
1601 }
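
// A schematic result-expansion example (hypothetical shapes): if the
// original op produced tensor<4x8xf32> and the collapsed op produces
// tensor<32xf32>, the original type is restored with, schematically,
//   %expanded = tensor.expand_shape %collapsed [[0, 1]]
//       : tensor<32xf32> into tensor<4x8xf32>
// so users of the original result are replaced type-correctly.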
1602 
1603 namespace {
1604 
1605 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1606 /// collapsing dimensions of the loop.
1607 class FoldWithProducerReshapeOpByCollapsing
1608  : public OpRewritePattern<GenericOp> {
1609 public:
1610  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1611  ControlFusionFn foldReshapes,
1612  PatternBenefit benefit = 1)
1613  : OpRewritePattern<GenericOp>(context, benefit),
1614  controlFoldingReshapes(std::move(foldReshapes)) {}
1615 
1616  LogicalResult matchAndRewrite(GenericOp genericOp,
1617  PatternRewriter &rewriter) const override {
1618  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1619  tensor::ExpandShapeOp reshapeOp =
1620  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1621  if (!reshapeOp)
1622  continue;
1623 
1624  SmallVector<ReassociationIndices> collapsableIterationDims =
1625  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1626  reshapeOp.getReassociationIndices());
1627  if (collapsableIterationDims.empty() ||
1628  !controlFoldingReshapes(&opOperand)) {
1629  continue;
1630  }
1631 
1632  std::optional<SmallVector<Value>> replacements =
1633  collapseOpIterationDims<linalg::GenericOp>(
1634  genericOp, collapsableIterationDims, rewriter);
1635  if (!replacements) {
1636  return rewriter.notifyMatchFailure(
1637  genericOp, "failed to do the fusion by collapsing transformation");
1638  }
1639 
1640  rewriter.replaceOp(genericOp, *replacements);
1641  return success();
1642  }
1643  return failure();
1644  }
1645 
1646 private:
1647  ControlFusionFn controlFoldingReshapes;
1648 };
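
// A schematic match for this pattern (hypothetical IR):
//   %e = tensor.expand_shape %arg [[0, 1]]
//       : tensor<32xf32> into tensor<4x8xf32>
//   %r = linalg.generic ... ins(%e : tensor<4x8xf32>) ...
// The generic's iteration dims {0, 1} are collapsed so it can consume %arg
// directly and the expand_shape is folded away, subject to the
// controlFoldingReshapes callback.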
1649 
1650 /// Pattern to collapse dimensions.
1651 template <typename LinalgType>
1652 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1653 public:
1654  CollapseLinalgDimensions(MLIRContext *context,
1655  GetCollapsableDimensionsFn collapseDimensions,
1656  PatternBenefit benefit = 1)
1657  : OpRewritePattern<LinalgType>(context, benefit),
1658  controlCollapseDimension(std::move(collapseDimensions)) {}
1659 
1660  LogicalResult matchAndRewrite(LinalgType op,
1661  PatternRewriter &rewriter) const override {
1662  SmallVector<ReassociationIndices> collapsableIterationDims =
1663  controlCollapseDimension(op);
1664  if (collapsableIterationDims.empty())
1665  return failure();
1666 
1667  // Check if the specified list of dimensions to collapse is a valid list.
1668  if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
1669  collapsableIterationDims)) {
1670  return rewriter.notifyMatchFailure(
1671  op, "specified dimensions cannot be collapsed");
1672  }
1673 
1674  std::optional<SmallVector<Value>> replacements =
1675  collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
1676  rewriter);
1677  if (!replacements) {
1678  return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
1679  }
1680  rewriter.replaceOp(op, *replacements);
1681  return success();
1682  }
1683 
1684 private:
1685  GetCollapsableDimensionsFn controlCollapseDimension;
1686 };
1687 
1688 } // namespace
1689 
1690 //===---------------------------------------------------------------------===//
1691 // Methods and patterns that fuse constants with linalg.generic operations.
1692 //===---------------------------------------------------------------------===//
1693 
1694 namespace {
1695 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1696 /// handle cases where the constant is not single-valued.
1697 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1698 public:
1699  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1700  : OpRewritePattern<GenericOp>(context, benefit) {}
1701 
1702  LogicalResult matchAndRewrite(GenericOp genericOp,
1703  PatternRewriter &rewriter) const override {
1704  if (!genericOp.hasPureTensorSemantics())
1705  return failure();
1706  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1707  Operation *def = opOperand->get().getDefiningOp();
1708  TypedAttr constantAttr;
1709  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1710  {
1711  DenseElementsAttr splatAttr;
1712  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1713  splatAttr.isSplat() &&
1714  splatAttr.getType().getElementType().isIntOrFloat()) {
1715  constantAttr = splatAttr.getSplatValue<TypedAttr>();
1716  return true;
1717  }
1718  }
1719  {
1720  IntegerAttr intAttr;
1721  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1722  constantAttr = intAttr;
1723  return true;
1724  }
1725  }
1726  {
1727  FloatAttr floatAttr;
1728  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1729  constantAttr = floatAttr;
1730  return true;
1731  }
1732  }
1733  return false;
1734  };
1735 
1736  auto resultValue = dyn_cast<OpResult>(opOperand->get());
1737  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1738  continue;
1739 
1740  // The operands and indexing_maps of the fused operation are the same as
1741  // those of the generic operation, but with the operand and indexing map
1742  // at the constant's index dropped.
1743  SmallVector<AffineMap> fusedIndexMaps;
1744  SmallVector<Value> fusedOperands;
1745  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1746  fusedIndexMaps.reserve(genericOp->getNumOperands());
1747  fusedOperands.reserve(genericOp.getNumDpsInputs());
1748  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1749  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1750  if (inputOperand == opOperand)
1751  continue;
1752  Value inputValue = inputOperand->get();
1753  fusedIndexMaps.push_back(
1754  genericOp.getMatchingIndexingMap(inputOperand));
1755  fusedOperands.push_back(inputValue);
1756  fusedLocs.push_back(inputValue.getLoc());
1757  }
1758  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1759  fusedIndexMaps.push_back(
1760  genericOp.getMatchingIndexingMap(&outputOperand));
1761 
1762  // Check if the operation shapes to loops map is computable.
1763  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1764  return rewriter.notifyMatchFailure(
1765  genericOp, "fused op loop bound computation failed");
1766  }
1767 
1768  // Create a constant scalar value from the splat constant.
1769  Value scalarConstant =
1770  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
1771 
1772  SmallVector<Value> outputOperands = genericOp.getOutputs();
1773  auto fusedOp = rewriter.create<GenericOp>(
1774  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1775  /*inputs=*/fusedOperands,
1776  /*outputs=*/outputOperands,
1777  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1778  genericOp.getIteratorTypes(),
1779  /*doc=*/nullptr,
1780  /*library_call=*/nullptr);
1781 
1782  // Map the block argument corresponding to the replaced argument with the
1783  // scalar constant.
1784  Region &region = genericOp->getRegion(0);
1785  Block &entryBlock = *region.begin();
1786  IRMapping mapping;
1787  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1788  scalarConstant);
1789  Region &fusedRegion = fusedOp->getRegion(0);
1790  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1791  mapping);
1792  rewriter.replaceOp(genericOp, fusedOp->getResults());
1793  return success();
1794  }
1795  return failure();
1796  }
1797 };
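
// A schematic rewrite for this pattern (hypothetical IR): an input fed by
//   %cst = arith.constant dense<2.0> : tensor<4xf32>
// is removed from the fused generic's operand list; the payload instead uses
// a scalar
//   %s = arith.constant 2.000000e+00 : f32
// mapped onto the block argument that previously read the splat tensor.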
1798 
1799 } // namespace
1800 
1801 //===---------------------------------------------------------------------===//
1802 // Miscellaneous patterns that help fusion.
1803 //===---------------------------------------------------------------------===//
1804 
1805 namespace {
1806 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
1807 /// value of the `outs` operand is not used within the op. This is only
1808 /// implemented for `linalg.generic` operations for now, but should hold for all
1809 /// linalg structured ops.
1810 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1811  using OpRewritePattern<GenericOp>::OpRewritePattern;
1812 
1813  LogicalResult matchAndRewrite(GenericOp op,
1814  PatternRewriter &rewriter) const override {
1815  rewriter.startOpModification(op);
1816  bool modifiedOutput = false;
1817  Location loc = op.getLoc();
1818  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1819  if (!op.payloadUsesValueFromOperand(&opOperand)) {
1820  Value operandVal = opOperand.get();
1821  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
1822  if (!operandType)
1823  continue;
1824 
1825  // If outs is sparse, leave it to the sparsifier.
1826  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
1827  continue;
1828 
1829  // If outs is already an `empty` operation, nothing to do.
1830  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
1831  if (definingOp)
1832  continue;
1833  modifiedOutput = true;
1834  SmallVector<OpFoldResult> mixedSizes =
1835  tensor::getMixedSizes(rewriter, loc, operandVal);
1836  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1837  loc, mixedSizes, operandType.getElementType());
1838  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
1839  }
1840  }
1841  if (!modifiedOutput) {
1842  rewriter.cancelOpModification(op);
1843  return failure();
1844  }
1845  rewriter.finalizeOpModification(op);
1846  return success();
1847  }
1848 };
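
// A schematic rewrite for this pattern (hypothetical IR): when the payload
// never reads the outs value,
//   linalg.generic ... outs(%arg : tensor<?x?xf32>)
// becomes
//   %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
//   linalg.generic ... outs(%empty : tensor<?x?xf32>)
// which removes a false dependency on %arg and unblocks fusion.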
1849 
1850 /// Fold linalg.fill into linalg.generic
1851 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
1852  using OpRewritePattern<GenericOp>::OpRewritePattern;
1853 
1854  LogicalResult matchAndRewrite(GenericOp genericOp,
1855  PatternRewriter &rewriter) const override {
1856  if (!genericOp.hasPureTensorSemantics())
1857  return failure();
1858  bool fillFound = false;
1859  Block &payload = genericOp.getRegion().front();
1860  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1861  if (!genericOp.payloadUsesValueFromOperand(opOperand))
1862  continue;
1863  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1864  if (!fillOp)
1865  continue;
1866  fillFound = true;
1867  Value fillVal = fillOp.value();
1868  auto resultType =
1869  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
1870  Value convertedVal =
1871  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
1872  /*isUnsignedCast =*/false);
1873  rewriter.replaceAllUsesWith(
1874  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
1875  }
1876  return success(fillFound);
1877  }
1878 };
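
// A schematic rewrite for this pattern (hypothetical IR): given
//   %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<8xf32>)
//       -> tensor<8xf32>
// used as a generic input whose payload reads it, every use of the matching
// block argument is replaced by %c0 (converted to the result element type),
// leaving the fill operand dead for later cleanup.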
1879 } // namespace
1880 
1881 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1882  RewritePatternSet &patterns,
1883  const ControlFusionFn &controlFoldingReshapes) {
1884  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1885  controlFoldingReshapes);
1886  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1887  controlFoldingReshapes);
1888 }
1889 
1890 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1891  RewritePatternSet &patterns,
1892  const ControlFusionFn &controlFoldingReshapes) {
1893  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
1894  controlFoldingReshapes);
1895 }
1896 
1897 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1898  RewritePatternSet &patterns,
1899  const ControlFusionFn &controlElementwiseOpsFusion) {
1900  auto *context = patterns.getContext();
1901  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1902  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1903  RemoveOutsDependency>(context);
1904  // Add the patterns that clean up dead operands and results.
1905  populateEraseUnusedOperandsAndResultsPatterns(patterns);
1906 }
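
// A minimal usage sketch (hypothetical driver code): register the fusion
// patterns with a single-use-producer heuristic and apply them greedily.
//   RewritePatternSet patterns(ctx);
//   linalg::ControlFusionFn control = [](OpOperand *fusedOperand) {
//     Operation *producer = fusedOperand->get().getDefiningOp();
//     return producer && producer->hasOneUse();
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, control);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));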
1907 
1908 void mlir::linalg::populateCollapseDimensions(
1909  RewritePatternSet &patterns,
1910  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
1911  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
1912  CollapseLinalgDimensions<linalg::CopyOp>>(
1913  patterns.getContext(), controlCollapseDimensions);
1914 }
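
// A minimal usage sketch (hypothetical control function): ask every op with
// at least two loops to fold its leading loop pair.
//   RewritePatternSet patterns(ctx);
//   linalg::GetCollapsableDimensionsFn control =
//       [](linalg::LinalgOp op) -> SmallVector<ReassociationIndices> {
//         if (op.getNumLoops() < 2)
//           return {};
//         return {ReassociationIndices{0, 1}};
//       };
//   linalg::populateCollapseDimensions(patterns, control);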
1915 
1916 //===---------------------------------------------------------------------===//
1917 // Passes
1918 //===---------------------------------------------------------------------===//
1919 
1920 namespace {
1921 
1922 /// Pass that fuses generic ops on tensors. Used only for testing.
1923 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
1924 // patterns added here heavily depends on the cost function used. Having an
1925 // opinionated pass of this form is not recommended. Deprecate this pass in
1926 // favor of test passes that check the functionality of each of the patterns
1927 // added here individually.
1928 struct LinalgElementwiseOpFusionPass
1929  : public impl::LinalgElementwiseOpFusionPassBase<
1930  LinalgElementwiseOpFusionPass> {
1931  using impl::LinalgElementwiseOpFusionPassBase<
1932  LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
1933  void runOnOperation() override {
1934  Operation *op = getOperation();
1935  MLIRContext *context = op->getContext();
1936  RewritePatternSet patterns(context);
1937 
1938  // Add folding with reshape by expansion patterns.
1939  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
1940  Operation *producer = fusedOperand->get().getDefiningOp();
1941  return producer && producer->hasOneUse();
1942  };
1943 
1944  // Add elementwise op fusion patterns.
1945  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
1946  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
1947 
1948  // General canonicalization patterns.
1949  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1950  GenericOp::getCanonicalizationPatterns(patterns, context);
1951  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1952  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1953  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1954  patterns);
1955 
1956  // Add constant folding patterns.
1957  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
1958 
1959  // Use TopDownTraversal for compile-time reasons.
1960  GreedyRewriteConfig grc;
1961  grc.useTopDownTraversal = true;
1962  (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
1963  }
1964 };
1965 
1966 } // namespace