//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass for
// elementwise operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
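
// An illustrative sketch (hypothetical maps, not taken from a test): if the
// producer is a 2-D transpose, i.e.
//   argMap                   = (d0, d1) -> (d0, d1)  (producer input access)
//   producerResultIndexMap   = (d0, d1) -> (d1, d0)  (producer result access)
// and the consumer reads the fused operand with the identity map
//   fusedConsumerArgIndexMap = (d0, d1) -> (d0, d1)
// then
//   invProducerResultIndexMap = (d0, d1) -> (d1, d0)
//   t1     = argMap.compose(invProducerResultIndexMap) = (d0, d1) -> (d1, d0)
//   result = t1.compose(fusedConsumerArgIndexMap)      = (d0, d1) -> (d1, d0)
// i.e. in the fused op the producer's input is read transposed with respect
// to the (fused) consumer loops, which is exactly what the transpose requires.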

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
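
// For intuition, a minimal sketch of a pair that satisfies these conditions
// (hypothetical IR, assuming #id = affine_map<(d0, d1) -> (d0, d1)>): both
// ops are all-parallel generics on tensors, the fused operand %0 is an input
// of the consumer, and the producer result map #id is a permutation:
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%a : tensor<?x?xf32>) outs(%init0 : tensor<?x?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//       outs(%init1 : tensor<?x?xf32>) {...}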

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.getRegion().push_back(fusedBlock);
  IRMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside the producer block.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all of the
  // yield operands).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
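
// To make the argument layout concrete, a sketch (hypothetical operand
// names): for a consumer with inputs (c0, fused, c1) and one init, and a
// producer with inputs (p0, p1) and one preserved init, the fused block
// arguments are, in order:
//   [c0, p0, p1, c1, producer_init, consumer_init]
// The `fused` operand itself gets no block argument; its uses are replaced by
// the cloned producer yield value in step 9 above.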

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer.
  llvm::SmallDenseSet<int> preservedProducerResults;
  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !producer.canOpOperandsBeDropped(outputOperand) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());
    }
  }

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
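
// A hedged usage sketch (not part of this file's logic): given a rewriter and
// a fusable operand `use`, the public entry point above can be driven as
//
//   FailureOr<linalg::ElementwiseOpFusionResult> res =
//       linalg::fuseElementwiseOps(rewriter, use);
//   if (succeeded(res))
//     for (auto [orig, repl] : res->replacements)
//       rewriter.replaceAllUsesWith(orig, repl);
//
// The pattern below does essentially this, but filters the replaced uses so
// that the producer's own uses are left untouched.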

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      // Find the producer of the operand.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
      Operation *producer = opOperand.get().getDefiningOp();

      // Do not fuse a sparse-in/dense-out operation, as the
      // result is too often not sparsifiable anymore.
      if (sparse_tensor::hasAnySparseOperand(producer) &&
          !sparse_tensor::hasAnySparseResult(producer))
        return failure();

      // Perform the fusion.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
454 } // namespace
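
// A minimal sketch of wiring this pattern up (assuming the standard greedy
// driver; upstream exposes this via a populate function rather than direct
// pattern construction):
//
//   RewritePatternSet patterns(ctx);
//   linalg::ControlFusionFn control = [](OpOperand *fusedOperand) {
//     return true; // fuse every operand the preconditions allow
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, control);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));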

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
/// following fusion pattern.
///
/// Consider
///
///   %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                           affine_map<(d0, d1, d2) -> (d1, d2)>,
///                           affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///   %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// reassociation map helps compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
///   %d = linalg.generic ins(%0, %1 : )
///          indexing_maps =
///            [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///             affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///             affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the operands to the linalg generic are now higher-dimensional,
/// reshapes can be introduced to make them consistent
///
///   %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///   %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///        : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
             0 &&
         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
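
// A worked sketch (hypothetical shapes, not from a test): for a generic op
// with loops (d0, d1) whose fusable operand is indexed by
// (d0, d1) -> (d0, d1), fused with an expanding reshape of reassociation
// [[0, 1], [2]] that turns tensor<8x8xf32> into tensor<2x4x8xf32>, `compute`
// produces
//   numExpandedDims   = [2, 1]
//   expandedShapeMap  = [[2, 4], [8]]
//   reassociation     = [[0, 1], [2]]
//   expandedOpNumDims = 3
// i.e. loop d0 is split into two expanded loops of extents 2 and 4, and d1 is
// carried over unchanged.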

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}
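
// For instance (an illustrative sketch): expanding an original loop d0 into
// two loops of shapes (?, 4) is accepted (only the outermost extent may be
// dynamic, since it is dropped by `drop_front` above), whereas shapes (4, ?)
// are rejected when the op has index semantics, because the linearization
// below needs the inner extents to be static multipliers.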

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
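
// Continuing the sketch above (d0 -> {e0, e1}, d1 -> {e2}): the original
// indexing map
//   (d0, d1) -> (d1, d0)
// becomes, in the expanded op,
//   (e0, e1, e2) -> (e2, e0, e1)
// since each original dim expression is replaced by the run of its expanded
// dims.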

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}
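
// Continuing the same sketch: an operand of type tensor<8x8xf32> indexed by
// (d0, d1) -> (d1, d0), with d0 expanded to shape [2, 4] and d1 kept as [8],
// gets the expanded type tensor<8x2x4xf32> (the expansion follows the order
// of the map results, not the loop order).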

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
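
// Continuing the same sketch: for the operand indexed by (d0, d1) -> (d1, d0)
// with d1 -> one expanded dim and d0 -> two expanded dims, the returned
// reassociation is [[0], [1, 2]]: result 0 (d1) stays a single dim, result 1
// (d0) expands into two dims of the reshaped operand.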

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
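
// A concrete sketch of the rewrite above: if the original index d0 expands
// into dims (e0, e1) with static inner extent 4, the loop below emits
//   %i0 = linalg.index 0 : index   // outermost expanded index (e0)
//   %i1 = linalg.index 1 : index   // inner expanded index (e1)
//   %new = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 4)>(%i1, %i0)
// which computes e0 * 4 + e1, the linearized value of the original index.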

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if the reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(genericOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumDpsInputs());
  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(genericOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(genericOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand.get(),
          reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getMatchingIndexingMap(
                  genericOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Return the (possibly collapsed) fused results.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return indexingMap.getResults()[pos]
            .cast<AffineDimExpr>()
            .getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
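
// Example (illustrative): with indexingMap = (d0, d1, d2) -> (d2, d0, d1) and
// rangeReassociation = [1, 2], the folded range positions select the results
// d0 and d1, so the returned domain reassociation is [0, 1].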

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            indexingMap.getResult(expr.index() + dimInSequence.index())
                .cast<AffineDimExpr>()
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}
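
// Examples (illustrative):
//   map = (d0, d1, d2) -> (d0, d1, d2), sequence [1, 2] -> true  (contiguous)
//   map = (d0, d1, d2) -> (d1, d0, d2), sequence [0, 1] -> false (interleaved)
//   map = (d0, d1, d2) -> (d0, d2),     sequence [1]    -> true  (d1 unused)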

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1)
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that all folded iterator types are all parallel or all reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move the window in `reductionDims` to the start of the folded
        // iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that's an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return the mapping from the collapsed loop domain to the original loop
  /// domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return the mapping from the original loop domain to the collapsed loop
  /// domain. The mapping is a pair. The first value is the dimension in the
  /// collapsed loop that the original loop is mapped to. The second is the
  /// relative position in the folded list of this domain. For example, if the
  /// original loop domain is 5D and folded as
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in the collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from the iteration domain index in the original op to the iteration
  /// domain index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<utils::IteratorType> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have verified that the iterator types of all
    // folded dimensions are the same.
    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    // If the dim is not the first of the collapsed dim, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The next n-dims are guaranteed to be collapsed. So just use the
    // iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}
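
// Example (illustrative): with collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
// the original indexing map
//   (d0, d1, d2, d3, d4) -> (d3, d4, d0, d1, d2)
// collapses to
//   (d0, d1) -> (d1, d0)
// since d3 starts the second folded group and d0 starts the first, and the
// non-leading dims of each group are dropped.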
1345 
1346 /// Return the `reassociation` indices to use to collapse the operand when the
1347 /// iteration space of a generic op is collapsed.
1350  const CollapsingInfo &collapsingInfo) {
1351  unsigned counter = 0;
1352  SmallVector<ReassociationIndices> operandReassociation;
1353  auto origOpToCollapsedOpMapping =
1354  collapsingInfo.getOrigOpToCollapsedOpMapping();
1355  auto collapsedOpToOrigOpMapping =
1356  collapsingInfo.getCollapsedOpToOrigOpMapping();
1357  while (counter < indexingMap.getNumResults()) {
1358  unsigned dim =
1359  indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
1360  // This is the start of a collapsed dimension of the iteration space that
1361  // is guaranteed to be preserved in the indexing map. The number of folded
1362  // dims is obtained from the collapsed op to original op mapping.
1363  unsigned numFoldedDims =
1364  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1365  .size();
1366  if (origOpToCollapsedOpMapping[dim].second == 0) {
1367  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1368  operandReassociation.emplace_back(range.begin(), range.end());
1369  }
1370  counter += numFoldedDims;
1371  }
1372  return operandReassociation;
1373 }
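// A worked example (values illustrative): with collapsing [[0, 1], [2]] and
// an operand indexing map (d0, d1, d2) -> (d0, d1, d2), the walk first sees
// d0 (start of the folded group {0, 1}, two dims) and then d2 (group {2},
// one dim), yielding the operand reassociation [[0, 1], [2]]: operand
// dimensions 0 and 1 are collapsed together and dimension 2 is left as is.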
1374 
1375 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1376 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1377  OpOperand *opOperand,
1378  const CollapsingInfo &collapsingInfo,
1379  OpBuilder &builder) {
1380  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
1381  SmallVector<ReassociationIndices> operandReassociation =
1382  getOperandReassociation(indexingMap, collapsingInfo);
1383 
1384  // If the number of entries in the reassociation for the operand is the same
1385  // as the number of results of the indexing map, there is nothing to do.
1386  Value operand = opOperand->get();
1387  if (operandReassociation.size() == indexingMap.getNumResults())
1388  return operand;
1389 
1390  // Insert a reshape to collapse the dimensions.
1391  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1392  loc, operand, operandReassociation);
1393  return reshapeOp.getResult();
1394 }
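// Sketch of the effect (shapes illustrative): for an operand of type
// tensor<2x3x4xf32> with operand reassociation [[0, 1], [2]], this emits
//   %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
//       : tensor<2x3x4xf32> into tensor<6x4xf32>
// and returns it; when no dimensions fold, the operand is returned unchanged
// and no reshape is created.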
1395 
1396 /// Rewrite the `linalg.index` operations in the original generic op to
1397 /// compute their values from the collapsed operation's induction variables.
1398 void generateCollapsedIndexingRegion(Location loc, Block *block,
1399  const CollapsingInfo &collapsingInfo,
1400  ValueRange loopRange,
1401  RewriterBase &rewriter) {
1402  OpBuilder::InsertionGuard g(rewriter);
1403  rewriter.setInsertionPointToStart(block);
1404 
1405  // Collect all the original index ops.
1406  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1407 
1408  // For each folded dimension list, resolve the original induction variable
1409  // values in terms of the folded dimension's induction variable:
1410  //   i_{folded} = (i_0 * d1 + i_1) * d2 + i_2,
1411  // which can be inverted to
1412  //   i_2 = i_{folded} % d2
1413  //   i_1 = (i_{folded} / d2) % d1
1414  //   i_0 = i_{folded} / (d1 * d2)
1415  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1416  for (auto foldedDims :
1417  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1418  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1419  Value newIndexVal =
1420  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1421  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1422  indexReplacementVals[dim] =
1423  rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1424  newIndexVal =
1425  rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1426  }
1427  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1428  }
1429 
1430  for (auto indexOp : indexOps) {
1431  auto dim = indexOp.getDim();
1432  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1433  }
1434 }
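// A numeric check of the inversion above (values assumed): with d1 = 3,
// d2 = 4 and (i_0, i_1, i_2) = (1, 2, 3), the folded index is
// i_{folded} = (1 * 3 + 2) * 4 + 3 = 23; then i_2 = 23 % 4 = 3,
// i_1 = (23 / 4) % 3 = 2, and i_0 = 23 / (3 * 4) = 1, matching the chain of
// arith.remui / arith.divui operations emitted per folded group.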
1435 
1436 /// Implementation of fusion with reshape operation by collapsing dimensions.
1437 FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
1438  GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1439  RewriterBase &rewriter) {
1440  // Bail on trivial no-op cases.
1441  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1442  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1443  return foldedDims.size() <= 1;
1444  }))
1445  return failure();
1446 
1447  CollapsingInfo collapsingInfo;
1448  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1449  foldedIterationDims))) {
1450  return rewriter.notifyMatchFailure(
1451  genericOp, "illegal to collapse specified dimensions");
1452  }
1453 
1454  // Bail on non-canonical ranges.
1455  SmallVector<Range> loopRanges =
1456  cast<LinalgOp>(genericOp.getOperation())
1457  .createLoopRanges(rewriter, genericOp.getLoc());
1458  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1459  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1460  return cast<IntegerAttr>(attr).getInt() == value;
1461  llvm::APInt actual;
1462  return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
1463  actual.getSExtValue() == value;
1464  };
1465  if (!llvm::all_of(loopRanges, [&](Range range) {
1466  return opFoldIsConstantValue(range.offset, 0) &&
1467  opFoldIsConstantValue(range.stride, 1);
1468  })) {
1469  return rewriter.notifyMatchFailure(
1470  genericOp,
1471  "expected all loop ranges to have zero start and unit stride");
1472  }
1473 
1474  // Get the iterator types for the collapsed operation.
1475  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
1476  genericOp.getIteratorTypesArray(), collapsingInfo);
1477 
1478  // Get the indexing maps.
1479  auto indexingMaps = llvm::to_vector(
1480  llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
1481  return getCollapsedOpIndexingMap(map, collapsingInfo);
1482  }));
1483 
1484  Location loc = genericOp->getLoc();
1485 
1486  // Get the input operands.
1487  auto inputOperands = llvm::to_vector(llvm::map_range(
1488  genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
1489  return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1490  rewriter);
1491  }));
1492 
1493  // Get the output operands and result types.
1494  SmallVector<Type> resultTypes;
1495  SmallVector<Value> outputOperands;
1496  resultTypes.reserve(genericOp.getNumDpsInits());
1497  outputOperands.reserve(genericOp.getNumDpsInits());
1498  for (OpOperand &output : genericOp.getDpsInitsMutable()) {
1499  Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
1500  collapsingInfo, rewriter);
1501  outputOperands.push_back(newOutput);
1502  resultTypes.push_back(newOutput.getType());
1503  }
1504 
1505  // Create the generic op.
1506  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
1507  loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1508  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1509  Block *origOpBlock = &genericOp->getRegion(0).front();
1510  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
1511  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1512  collapsedOpBlock->getArguments());
1513 
1514  if (collapsedGenericOp.hasIndexSemantics()) {
1515  // Collect the loop range of the generic op.
1516  OpBuilder::InsertionGuard g(rewriter);
1517  rewriter.setInsertionPoint(collapsedGenericOp);
1518  SmallVector<Value> loopBound =
1519  llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
1520  return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1521  }));
1522  generateCollapsedIndexingRegion(loc,
1523  &collapsedGenericOp->getRegion(0).front(),
1524  collapsingInfo, loopBound, rewriter);
1525  }
1526 
1527  // Insert expanding reshape for the result to get back the original result
1528  // type.
1529  SmallVector<Value> results;
1530  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
1531  Value collapsedOpResult =
1532  collapsedGenericOp->getResult(originalResult.index());
1533  auto originalResultType =
1534  cast<ShapedType>(originalResult.value().getType());
1535  auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1536  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1537  AffineMap indexingMap =
1538  genericOp.getIndexingMapMatchingResult(originalResult.value());
1539  SmallVector<ReassociationIndices> reassociation =
1540  getOperandReassociation(indexingMap, collapsingInfo);
1541  Value result = rewriter.create<tensor::ExpandShapeOp>(
1542  loc, originalResultType, collapsedOpResult, reassociation);
1543  results.push_back(result);
1544  } else {
1545  results.push_back(collapsedOpResult);
1546  }
1547  }
1548  return results;
1549 }
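// End-to-end sketch (IR and types illustrative): collapsing dims [[0, 1]] of
//   %0 = linalg.generic {indexing_maps = [#id2d, #id2d],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%arg0 : tensor<4x5xf32>) outs(%init : tensor<4x5xf32>) ...
// yields a rank-1 generic whose operands are built with tensor.collapse_shape
// (tensor<4x5xf32> into tensor<20xf32>), followed by a tensor.expand_shape on
// the result to restore the original tensor<4x5xf32> type.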
1550 
1551 namespace {
1552 
1553 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1554 /// collapsing dimensions of the iteration space.
1555 class FoldWithProducerReshapeOpByCollapsing
1556  : public OpRewritePattern<GenericOp> {
1557 public:
1558  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1559  ControlFusionFn foldReshapes,
1560  PatternBenefit benefit = 1)
1561  : OpRewritePattern<GenericOp>(context, benefit),
1562  controlFoldingReshapes(std::move(foldReshapes)) {}
1563 
1564  LogicalResult matchAndRewrite(GenericOp genericOp,
1565  PatternRewriter &rewriter) const override {
1566  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1567  tensor::ExpandShapeOp reshapeOp =
1568  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1569  if (!reshapeOp)
1570  continue;
1571 
1572  SmallVector<ReassociationIndices> collapsableIterationDims =
1573  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1574  reshapeOp.getReassociationIndices());
1575  if (collapsableIterationDims.empty() ||
1576  !controlFoldingReshapes(&opOperand)) {
1577  continue;
1578  }
1579 
1580  std::optional<SmallVector<Value>> replacements =
1581  collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1582  rewriter);
1583  if (!replacements) {
1584  return rewriter.notifyMatchFailure(
1585  genericOp, "failed to do the fusion by collapsing transformation");
1586  }
1587 
1588  rewriter.replaceOp(genericOp, *replacements);
1589  return success();
1590  }
1591  return failure();
1592  }
1593 
1594 private:
1595  ControlFusionFn controlFoldingReshapes;
1596 };
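// Sketch of the rewrite this pattern enables (IR illustrative):
//   %e = tensor.expand_shape %a [[0, 1]] : tensor<20xf32> into tensor<4x5xf32>
//   %r = linalg.generic ... ins(%e : tensor<4x5xf32>) ...
// becomes a generic with the matching iteration dimensions collapsed, so the
// consumer can read %a directly and the expand_shape no longer feeds it.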
1597 
1598 /// Pattern to collapse dimensions.
1599 class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
1600 public:
1601  CollapseLinalgDimensions(MLIRContext *context,
1602  GetCollapsableDimensionsFn collapseDimensions,
1603  PatternBenefit benefit = 1)
1604  : OpRewritePattern<GenericOp>(context, benefit),
1605  controlCollapseDimension(std::move(collapseDimensions)) {}
1606 
1607  LogicalResult matchAndRewrite(GenericOp genericOp,
1608  PatternRewriter &rewriter) const override {
1609  SmallVector<ReassociationIndices> collapsableIterationDims =
1610  controlCollapseDimension(genericOp);
1611  if (collapsableIterationDims.empty())
1612  return failure();
1613 
1614  // Check if the specified list of dimensions to collapse is a valid list.
1615  if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
1616  collapsableIterationDims)) {
1617  return rewriter.notifyMatchFailure(
1618  genericOp, "specified dimensions cannot be collapsed");
1619  }
1620 
1621  std::optional<SmallVector<Value>> replacements =
1622  collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1623  rewriter);
1624  if (!replacements) {
1625  return rewriter.notifyMatchFailure(genericOp,
1626  "failed to collapse dimensions");
1627  }
1628  rewriter.replaceOp(genericOp, *replacements);
1629  return success();
1630  }
1631 
1632 private:
1633  GetCollapsableDimensionsFn controlCollapseDimension;
1634 };
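// A minimal sketch of a control callback for this pattern (name and policy
// hypothetical): request collapsing the first two loops whenever they exist.
//   GetCollapsableDimensionsFn collapseFirstTwo =
//       [](linalg::GenericOp op) -> SmallVector<ReassociationIndices> {
//         if (op.getNumLoops() < 2)
//           return {};
//         return {ReassociationIndices{0, 1}};
//       };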
1635 
1636 } // namespace
1637 
1638 //===---------------------------------------------------------------------===//
1639 // Methods and patterns that fuse constants with linalg.generic operations.
1640 //===---------------------------------------------------------------------===//
1641 
1642 namespace {
1643 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1644 /// handle cases where the constant is not single-valued.
1645 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1646 public:
1647  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1648  : OpRewritePattern<GenericOp>(context, benefit) {}
1649 
1650  LogicalResult matchAndRewrite(GenericOp genericOp,
1651  PatternRewriter &rewriter) const override {
1652  if (!genericOp.hasTensorSemantics())
1653  return failure();
1654  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1655  Operation *def = opOperand->get().getDefiningOp();
1656  TypedAttr constantAttr;
1657  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1658  {
1659  DenseElementsAttr splatAttr;
1660  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1661  splatAttr.isSplat() &&
1662  splatAttr.getType().getElementType().isIntOrFloat()) {
1663  constantAttr = splatAttr.getSplatValue<TypedAttr>();
1664  return true;
1665  }
1666  }
1667  {
1668  IntegerAttr intAttr;
1669  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1670  constantAttr = intAttr;
1671  return true;
1672  }
1673  }
1674  {
1675  FloatAttr floatAttr;
1676  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1677  constantAttr = floatAttr;
1678  return true;
1679  }
1680  }
1681  return false;
1682  };
1683 
1684  auto resultValue = dyn_cast<OpResult>(opOperand->get());
1685  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1686  continue;
1687 
1688  // The operands and the indexing_maps of the fused operation are the same
1689  // as the operands and indexing_maps of the generic operation, with the
1690  // entries at the constant operand's index dropped.
1691  SmallVector<AffineMap> fusedIndexMaps;
1692  SmallVector<Value> fusedOperands;
1693  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1694  fusedIndexMaps.reserve(genericOp->getNumOperands());
1695  fusedOperands.reserve(genericOp.getNumDpsInputs());
1696  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1697  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1698  if (inputOperand == opOperand)
1699  continue;
1700  Value inputValue = inputOperand->get();
1701  fusedIndexMaps.push_back(
1702  genericOp.getMatchingIndexingMap(inputOperand));
1703  fusedOperands.push_back(inputValue);
1704  fusedLocs.push_back(inputValue.getLoc());
1705  }
1706  for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1707  fusedIndexMaps.push_back(
1708  genericOp.getMatchingIndexingMap(&outputOperand));
1709 
1710  // Check if the operation's shapes-to-loops map is computable.
1711  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1712  return rewriter.notifyMatchFailure(
1713  genericOp, "fused op loop bound computation failed");
1714  }
1715 
1716  // Create a constant scalar value from the splat constant.
1717  Value scalarConstant =
1718  rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
1719 
1720  SmallVector<Value> outputOperands = genericOp.getOutputs();
1721  auto fusedOp = rewriter.create<GenericOp>(
1722  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1723  /*inputs=*/fusedOperands,
1724  /*outputs=*/outputOperands,
1725  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1726  genericOp.getIteratorTypes(),
1727  /*doc=*/nullptr,
1728  /*library_call=*/nullptr);
1729 
1730  // Map the block argument corresponding to the replaced argument with the
1731  // scalar constant.
1732  Region &region = genericOp->getRegion(0);
1733  Block &entryBlock = *region.begin();
1734  IRMapping mapping;
1735  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1736  scalarConstant);
1737  Region &fusedRegion = fusedOp->getRegion(0);
1738  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1739  mapping);
1740  rewriter.replaceOp(genericOp, fusedOp->getResults());
1741  return success();
1742  }
1743  return failure();
1744  }
1745 };
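// Sketch of the folding (IR illustrative): given
//   %cst = arith.constant dense<2.0> : tensor<4xf32>
//   %r = linalg.generic ... ins(%arg0, %cst : tensor<4xf32>, tensor<4xf32>) ...
// the pattern rebuilds the generic without the %cst operand and replaces the
// corresponding block argument with a scalar arith.constant 2.0 : f32 used
// directly inside the fused region.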
1746 
1747 } // namespace
1748 
1749 //===---------------------------------------------------------------------===//
1750 // Miscellaneous patterns that help fusion.
1751 //===---------------------------------------------------------------------===//
1752 
1753 namespace {
1754 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
1755 /// value of the `outs` operand is not used within the op. This is only
1756 /// implemented for `linalg.generic` operations for now, but should hold for all
1757 /// linalg structured ops.
1758 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1759  using OpRewritePattern<GenericOp>::OpRewritePattern;
1760 
1761  LogicalResult matchAndRewrite(GenericOp op,
1762  PatternRewriter &rewriter) const override {
1763  rewriter.startRootUpdate(op);
1764  bool modifiedOutput = false;
1765  Location loc = op.getLoc();
1766  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1767  if (!op.payloadUsesValueFromOperand(&opOperand)) {
1768  Value operandVal = opOperand.get();
1769  auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
1770  if (!operandType)
1771  continue;
1772 
1773  // If outs is sparse, leave it to the sparse compiler.
1774  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
1775  continue;
1776 
1777  // If outs is already an `empty` operation, nothing to do.
1778  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
1779  if (definingOp)
1780  continue;
1781  modifiedOutput = true;
1782  SmallVector<OpFoldResult> mixedSizes =
1783  tensor::getMixedSizes(rewriter, loc, operandVal);
1784  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1785  loc, mixedSizes, operandType.getElementType());
1786  op->setOperand(opOperand.getOperandNumber(), emptyTensor);
1787  }
1788  }
1789  if (!modifiedOutput) {
1790  rewriter.cancelRootUpdate(op);
1791  return failure();
1792  }
1793  rewriter.finalizeRootUpdate(op);
1794  return success();
1795  }
1796 };
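// Sketch (IR illustrative): when the payload never reads the outs block
// argument, a destination tied to a real producer, e.g.
//   %r = linalg.generic ... outs(%computed : tensor<4xf32>)
// is rewritten to use a fresh destination
//   %empty = tensor.empty() : tensor<4xf32>
//   %r = linalg.generic ... outs(%empty : tensor<4xf32>)
// removing a false dependency that could otherwise block fusion.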
1797 
1798 /// Fold linalg.fill into linalg.generic
1799 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
1800  using OpRewritePattern<GenericOp>::OpRewritePattern;
1801 
1802  LogicalResult matchAndRewrite(GenericOp genericOp,
1803  PatternRewriter &rewriter) const override {
1804  if (!genericOp.hasTensorSemantics())
1805  return failure();
1806  bool fillFound = false;
1807  Block &payload = genericOp.getRegion().front();
1808  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1809  if (!genericOp.payloadUsesValueFromOperand(opOperand))
1810  continue;
1811  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1812  if (!fillOp)
1813  continue;
1814  fillFound = true;
1815  Value fillVal = fillOp.value();
1816  auto resultType =
1817  cast<RankedTensorType>(fillOp.result().getType()).getElementType();
1818  Value convertedVal =
1819  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
1820  /*isUnsignedCast =*/false);
1821  rewriter.replaceAllUsesWith(
1822  payload.getArgument(opOperand->getOperandNumber()), convertedVal);
1823  }
1824  return success(fillFound);
1825  }
1826 };
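// Sketch (IR illustrative): with
//   %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
//   %r = linalg.generic ... ins(%fill : tensor<4xf32>) ...
// the block argument that reads %fill is replaced by %c0 (converted to the
// result element type if needed), so the fill value is consumed as a scalar
// inside the payload and the fill can become dead.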
1827 } // namespace
1828 
1829 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1830  RewritePatternSet &patterns,
1831  const ControlFusionFn &controlFoldingReshapes) {
1832  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1833  controlFoldingReshapes);
1834  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1835  controlFoldingReshapes);
1836 }
1837 
1838 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1839  RewritePatternSet &patterns,
1840  const ControlFusionFn &controlFoldingReshapes) {
1841  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
1842  controlFoldingReshapes);
1843 }
1844 
1845 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1846  RewritePatternSet &patterns,
1847  const ControlFusionFn &controlElementwiseOpsFusion) {
1848  auto *context = patterns.getContext();
1849  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1850  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1851  RemoveOutsDependency>(context);
1852  // Add the patterns that clean up dead operands and results.
1853  populateEraseUnusedOperandsAndResultsPatterns(patterns);
1854 }
1855 
1856 void mlir::linalg::populateCollapseDimensions(
1857  RewritePatternSet &patterns,
1858  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
1859  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
1860  controlCollapseDimensions);
1861 }
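// A minimal usage sketch for these entry points (driver code assumed, not
// part of this file):
//   RewritePatternSet patterns(ctx);
//   linalg::ControlFusionFn control = [](OpOperand *fusedOperand) {
//     Operation *producer = fusedOperand->get().getDefiningOp();
//     return producer && producer->hasOneUse();
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, control);
//   linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, control);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));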
1862 
1863 //===---------------------------------------------------------------------===//
1864 // Passes
1865 //===---------------------------------------------------------------------===//
1866 
1867 namespace {
1868 
1869 /// Pass that fuses generic ops on tensors. Used only for testing.
1870 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
1871 // patterns added here heavily depends on the cost function used. Having an
1872 // opinionated pass of this form is not recommended. Deprecate this pass in
1873 // favor of test passes that check the functionality of each of the patterns
1874 // added here individually.
1875 struct LinalgElementwiseOpFusionPass
1876  : public impl::LinalgElementwiseOpFusionBase<
1877  LinalgElementwiseOpFusionPass> {
1878  void runOnOperation() override {
1879  Operation *op = getOperation();
1880  MLIRContext *context = op->getContext();
1881  RewritePatternSet patterns(context);
1882 
1883  // Add folding with reshape by expansion patterns.
1884  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
1885  Operation *producer = fusedOperand->get().getDefiningOp();
1886  return producer && producer->hasOneUse();
1887  };
1888 
1889  // Add elementwise op fusion patterns.
1890  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
1891  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
1892 
1893  // General canonicalization patterns.
1894  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1895  GenericOp::getCanonicalizationPatterns(patterns, context);
1896  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1897  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1898  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1899  patterns);
1900 
1901  // Add constant folding patterns.
1902  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
1903 
1904  // Use TopDownTraversal for compile-time reasons.
1905  GreedyRewriteConfig grc;
1906  grc.useTopDownTraversal = true;
1907  (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
1908  }
1909 };
1910 
1911 } // namespace
1912 
1913 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1914  return std::make_unique<LinalgElementwiseOpFusionPass>();
1915 }
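// The pass above is typically exercised from the command line via its
// registered flag (assuming the registration in Passes.td is unchanged):
//   mlir-opt --linalg-fuse-elementwise-ops input.mlir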