//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the result of the producer in
/// the consumer (`fusedConsumerArgIndexMap`).
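///
/// A worked illustration (not from the original source): suppose the producer
/// result map is (d0, d1) -> (d1, d0) and the producer argument map is the
/// identity (d0, d1) -> (d0, d1). Inverting the result map gives
/// (t0, t1) -> (t1, t0), a map from tensor index to producer loops. Composing
/// the argument map with that inverse again gives (t0, t1) -> (t1, t0), a map
/// from tensor index to argument index. Composing with an identity consumer
/// argument map finally yields (c0, c1) -> (c1, c0), i.e. the producer
/// argument is read transposed in the fused loops.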
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Conditions for elementwise fusion of generic operations.
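///
/// For instance (an illustrative sketch, not from the original source, with
/// #id = affine_map<(d0) -> (d0)>), the following producer/consumer pair
/// satisfies the conditions checked below and can be fused:
///
///   %0 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%a : tensor<?xf32>) outs(%i0 : tensor<?xf32>) { ... }
///   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%0, %b : tensor<?xf32>, tensor<?xf32>)
///       outs(%i1 : tensor<?xf32>) { ... }
///
/// The producer is all-parallel, has tensor semantics, its result indexing
/// map is a permutation, and it feeds an input operand of the consumer.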
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasTensorSemantics() ||
      !fusedOperand->get().getType().isa<RankedTensorType>())
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
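///
/// The fused block arguments are laid out to match the fused op's operand
/// list (this summarizes steps 3-7 below):
///   [consumer inputs before the fused operand, producer inputs,
///    consumer inputs after the fused operand, preserved producer outputs,
///    consumer outputs]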
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.getRegion().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      fusedOperand->get().cast<OpResult>().getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yield values).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<Operation *>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = fusedOperand->get().cast<OpResult>();
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer.
  llvm::SmallDenseSet<int> preservedProducerResults;
  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto outputOperand = producer.getDpsInitOperand(producerResult.index());
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !producer.canOpOperandsBeDropped(outputOperand) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());
    }
  }

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitOperands())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value()->get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value()->get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand *opOperand : consumer.getDpsInitOperands()) {
    fusedOutputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
    Type resultType = opOperand->get().getType();
    if (!resultType.isa<MemRefType>())
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  return fusedOp.getOperation();
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
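///
/// A minimal driver sketch (assuming the public entry point
/// `populateElementwiseOpsFusionPatterns` declared in Linalg's Transforms.h,
/// which instantiates this pattern; not part of this file):
///
///   RewritePatternSet patterns(ctx);
///   ControlFusionFn controlFn = [](OpOperand *fusedOperand) { return true; };
///   populateElementwiseOpsFusionPatterns(patterns, controlFn);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));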
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
      if (succeeded(fusedOp)) {
        auto replacements =
            fusedOp.value()->getResults().take_back(genericOp.getNumResults());
        rewriter.replaceOp(genericOp, replacements);
        return success();
      }
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//
433 
434 /// Conditions for folding a generic operation with a reshape op by expanding
435 /// the iteration space dimensionality for tensor operations. These are
436 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
437 /// following fusion pattern.
438 ///
439 /// Consider
440 ///
441 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
442 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
443 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
444 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
445 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
446 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
447 ///
448 /// The reshape can be folded into the `genericOp` if its loop dimensionality
449 /// is increased to match the result (operand) of the tensor.expand_shape.
450 /// The indexing_map of the fused tensor in the `genericOp` and the
451 /// reassociation map helps compute the indexing maps of the modified op.
452 /// For the above example, based on the reassociation map it
453 /// can be concluded that
454 ///
455 /// - The loop used to access the first dimension of the fused tensor is split
456 /// into two.
457 /// - The loop used to access the second dimension of the fused tensor is kept
458 /// as is.
459 /// - The loop used to access the third dimension of the fused tensor is split
460 /// into three.
461 ///
462 /// i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
463 /// op, then
464 ///
465 /// d0 -> e0, e1
466 /// d1 -> e2, e3, e4
467 /// d2 -> e5
468 ///
469 /// substituting this, the generic op can be rewritten as
470 ///
471 /// %d = linalg.generic ins(%0, %1 : )
472 /// indexing_maps =
473 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
474 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
475 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
476 ///
477 /// Since operands to the linalg generic are now 5D, reshapes can be introduced
478 /// to make it consistent
479 ///
480 /// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
481 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
482 /// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
483 /// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
484 ///
485 /// The added reshapes are again expanding patterns, so they will get fused
486 /// with its producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
             0 &&
         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
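  //
  // For example (an illustration, not from the original source): if the fused
  // operand has indexing map (d0, d1) -> (d1, d0), the reshape reassociation
  // is [[0, 1], [2]] and the expanded shape is [4, 8, 16], then d1 expands to
  // two loops of sizes [4, 8] and d0 to one loop of size [16], giving the
  // original-to-expanded reassociation d0 -> {0}, d1 -> {1, 2}.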
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op, given the `indexingMap`
/// of the original operation.
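///
/// For example (an illustration, not from the original source), if d0 expands
/// to (e0, e1) and d1 to (e2), the original map (d0, d1) -> (d1, d0) becomes
/// (e0, e1, e2) -> (e2, e0, e1) in the expanded op.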
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
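///
/// For example (an illustration, not from the original source), with
/// d0 -> sizes [2, 3] and d1 -> sizes [4], an operand of type tensor<6x4xf32>
/// indexed by (d0, d1) -> (d0, d1) gets the expanded type tensor<2x3x4xf32>.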
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
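///
/// For example (an illustration, not from the original source), with
/// d0 -> (e0, e1) and d1 -> (e2), an operand indexed by (d0, d1) -> (d1, d0)
/// gets the reassociation [[0], [1, 2]]: its first result dim stays a single
/// dim while the second covers two expanded dims.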
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
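///
/// For example (an illustration, not from the original source), if loop d0
/// expands into loops (e0, e1) where e1 has static extent 4, a
/// `linalg.index 0` in the original body is rewritten roughly as:
///
///   %e0 = linalg.index 0 : index
///   %e1 = linalg.index 1 : index
///   %d0 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 4)>(%e1, %e0)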
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(genericOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumDpsInputs());
  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            opOperand->get().getType().dyn_cast<RankedTensorType>()) {
      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(genericOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getDpsInitOperands()) {
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
    auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(genericOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    } else {
      outputs.push_back(opOperand->get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getMatchingIndexingMap(
                  genericOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = reshapeOp.getSrc().dyn_cast<OpResult>();
    if (!producerResult) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "source not produced by an operation");
    }

    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[reshapeOp.getSrc()
                                 .cast<OpResult>()
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
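///
/// For example (an illustration, not from the original source), with
/// indexingMap (d0, d1, d2) -> (d2, d0, d1) and rangeReassociation [0, 1],
/// range positions 0 and 1 correspond to results d2 and d0, so the returned
/// domain reassociation is [2, 0].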
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return indexingMap.getResults()[pos]
            .cast<AffineDimExpr>()
            .getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
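///
/// For example (an illustration, not from the original source), the sequence
/// [1, 2] is preserved in (d0, d1, d2) -> (d1, d2, d0) since d1 and d2 appear
/// contiguously and in order, but not in (d0, d1, d2) -> (d2, d1), where an
/// element of the sequence appears without the full sequence starting there.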
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            indexingMap.getResult(expr.index() + dimInSequence.index())
                .cast<AffineDimExpr>()
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of sequence found. Return true.
  return true;
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1)
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are either all parallel or all
    // reduction.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration
        // dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, trivially not contiguous. This condition
        // should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that is an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return mapping from original loop domain to collapsed loop domain. The
  /// mapping is a pair. The first value is the dimension in the collapsed loop
  /// that the original loop is mapped to. The second is the relative position
  /// in the folded list of this domain. For example, if the original loop
  /// domain is 5D and folded as
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in the collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from the iteration domain index in the original op to the iteration
  /// domain index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<utils::IteratorType> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have checked that the iterator types of all
    // folded dimensions are the same.
    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
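///
/// For example (an illustration, not from the original source), when loops
/// (d0, d1, d2) are collapsed as [[0, 1], [2]], the original map
/// (d0, d1, d2) -> (d0, d1, d2) becomes (d0, d1) -> (d0, d1) in the collapsed
/// op: the second original dim is dropped because it is not the first dim of
/// its folded group.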
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    // If the dim is not the first of the collapsed dim, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The next n-dims are guaranteed to be collapsed. So just use the
    // iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}

/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
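///
/// For example (an illustration, not from the original source), continuing
/// the example above, an operand indexed by (d0, d1, d2) -> (d0, d1, d2) with
/// loops collapsed as [[0, 1], [2]] gets the operand reassociation
/// [[0, 1], [2]]: its first two result dims fold into a single dim.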
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
    // This is the start of a collapsed dimension of the iteration space that
    // is guaranteed to be preserved in the indexing map. The number of folded
    // dims is obtained from the collapsed op to original op mapping.
    unsigned numFoldedDims =
        collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
            .size();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
    }
    counter += numFoldedDims;
  }
  return operandReassociation;
}

1336 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1337 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1338  OpOperand *opOperand,
1339  const CollapsingInfo &collapsingInfo,
1340  OpBuilder &builder) {
1341  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
1342  SmallVector<ReassociationIndices> operandReassociation =
1343  getOperandReassociation(indexingMap, collapsingInfo);
1344 
1345  // If the number of entries in the reassociation for the operand is the same
1346  // as the number of results of the indexing map, nothing to do for this operand.
1347  Value operand = opOperand->get();
1348  if (operandReassociation.size() == indexingMap.getNumResults())
1349  return operand;
1350 
1351  // Insert a reshape to collapse the dimensions.
1352  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1353  loc, operand, operandReassociation);
1354  return reshapeOp.getResult();
1355 }
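// Sketch of the reshape this inserts, with hypothetical types: for a 4-d
// operand and operand reassociation [[0, 1], [2, 3]]:
//
//   %collapsed = tensor.collapse_shape %operand [[0, 1], [2, 3]]
//       : tensor<4x8x16x32xf32> into tensor<32x512xf32>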
1356 
1357 /// Rewrite the `linalg.index` operations in the original generic op to compute
1358 /// their values in terms of the collapsed operation's induction variables.
1359 void generateCollapsedIndexingRegion(Location loc, Block *block,
1360  const CollapsingInfo &collapsingInfo,
1361  ValueRange loopRange,
1362  RewriterBase &rewriter) {
1363  OpBuilder::InsertionGuard g(rewriter);
1364  rewriter.setInsertionPointToStart(block);
1365 
1366  // Collect all the original index ops.
1367  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1368 
1369  // For each folded dimension list, resolve the original induction variable
1370  // values in terms of the folded dimension's induction variable:
1371  // i_{folded} = (i0 * d1 + i1) * d2 + i2
1372  // can be inverted to
1373  // i2 = i_{folded} % d2
1374  // i1 = (i_{folded} / d2) % d1
1375  // i0 = i_{folded} / (d1 * d2)
1376  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1377  for (auto &foldedDims :
1378  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1379  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1380  Value newIndexVal =
1381  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1382  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1383  indexReplacementVals[dim] =
1384  rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1385  newIndexVal =
1386  rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1387  }
1388  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1389  }
1390 
1391  for (auto indexOp : indexOps) {
1392  auto dim = indexOp.getDim();
1393  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1394  }
1395 }
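// Sketch of the IR this generates, assuming hypothetical loop sizes %d1 and
// %d2 and original dimensions [0, 1, 2] folded into collapsed dimension 0:
//
//   %l  = linalg.index 0 : index
//   %i2 = arith.remui %l, %d2 : index
//   %q  = arith.divui %l, %d2 : index
//   %i1 = arith.remui %q, %d1 : index
//   %i0 = arith.divui %q, %d1 : index
//
// Uses of linalg.index 0, 1 and 2 in the original body are then replaced by
// %i0, %i1 and %i2 respectively.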
1396 
1397 /// Implementation of fusion with reshape operation by collapsing dimensions.
1398 FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
1399  GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1400  RewriterBase &rewriter) {
1401  // Bail on trivial no-op cases.
1402  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1403  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1404  return foldedDims.size() <= 1;
1405  }))
1406  return failure();
1407 
1408  CollapsingInfo collapsingInfo;
1409  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1410  foldedIterationDims))) {
1411  return rewriter.notifyMatchFailure(
1412  genericOp, "illegal to collapse specified dimensions");
1413  }
1414 
1415  // Bail on non-canonical ranges.
1416  SmallVector<Range> loopRanges =
1417  cast<LinalgOp>(genericOp.getOperation())
1418  .createLoopRanges(rewriter, genericOp.getLoc());
1419  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1420  if (auto attr = ofr.dyn_cast<Attribute>())
1421  return attr.cast<IntegerAttr>().getInt() == value;
1422  llvm::APInt actual;
1423  return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
1424  actual.getSExtValue() == value;
1425  };
1426  if (!llvm::all_of(loopRanges, [&](Range range) {
1427  return opFoldIsConstantValue(range.offset, 0) &&
1428  opFoldIsConstantValue(range.stride, 1);
1429  })) {
1430  return rewriter.notifyMatchFailure(
1431  genericOp,
1432  "expected all loop ranges to have zero start and unit stride");
1433  }
1434 
1435  // Get the iterator types for the collapsed operation.
1436  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
1437  genericOp.getIteratorTypesArray(), collapsingInfo);
1438 
1439  // Get the indexing maps.
1440  auto indexingMaps = llvm::to_vector(
1441  llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
1442  return getCollapsedOpIndexingMap(map, collapsingInfo);
1443  }));
1444 
1445  Location loc = genericOp->getLoc();
1446 
1447  // Get the input operands.
1448  auto inputOperands = llvm::to_vector(llvm::map_range(
1449  genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
1450  return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1451  rewriter);
1452  }));
1453 
1454  // Get the output operands and result types.
1455  SmallVector<Type> resultTypes;
1456  SmallVector<Value> outputOperands;
1457  resultTypes.reserve(genericOp.getNumDpsInits());
1458  outputOperands.reserve(genericOp.getNumDpsInits());
1459  for (OpOperand *output : genericOp.getDpsInitOperands()) {
1460  Value newOutput =
1461  getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
1462  outputOperands.push_back(newOutput);
1463  resultTypes.push_back(newOutput.getType());
1464  }
1465 
1466  // Create the generic op.
1467  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
1468  loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1469  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1470  Block *origOpBlock = &genericOp->getRegion(0).front();
1471  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
1472  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1473  collapsedOpBlock->getArguments());
1474 
1475  if (collapsedGenericOp.hasIndexSemantics()) {
1476  // Collect the loop range of the generic op.
1477  OpBuilder::InsertionGuard g(rewriter);
1478  rewriter.setInsertionPoint(collapsedGenericOp);
1479  SmallVector<Value> loopBound =
1480  llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
1481  return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1482  }));
1483  generateCollapsedIndexingRegion(loc,
1484  &collapsedGenericOp->getRegion(0).front(),
1485  collapsingInfo, loopBound, rewriter);
1486  }
1487 
1488  // Insert expanding reshapes on the results to get back the original result
1489  // types.
1490  SmallVector<Value> results;
1491  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
1492  Value collapsedOpResult =
1493  collapsedGenericOp->getResult(originalResult.index());
1494  auto originalResultType =
1495  originalResult.value().getType().cast<ShapedType>();
1496  auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
1497  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1498  AffineMap indexingMap =
1499  genericOp.getIndexingMapMatchingResult(originalResult.value());
1500  SmallVector<ReassociationIndices> reassociation =
1501  getOperandReassociation(indexingMap, collapsingInfo);
1502  Value result = rewriter.create<tensor::ExpandShapeOp>(
1503  loc, originalResultType, collapsedOpResult, reassociation);
1504  results.push_back(result);
1505  } else {
1506  results.push_back(collapsedOpResult);
1507  }
1508  }
1509  return results;
1510 }
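// End-to-end sketch of the transformation on a hypothetical elementwise op,
// collapsing iteration dims [[0, 1]]:
//
//   %in  = tensor.collapse_shape %arg0 [[0, 1]]
//            : tensor<4x8xf32> into tensor<32xf32>
//   %out = tensor.collapse_shape %init [[0, 1]]
//            : tensor<4x8xf32> into tensor<32xf32>
//   %gen = linalg.generic {
//            indexing_maps = [affine_map<(d0) -> (d0)>,
//                             affine_map<(d0) -> (d0)>],
//            iterator_types = ["parallel"]}
//            ins(%in : tensor<32xf32>) outs(%out : tensor<32xf32>) {...}
//   %res = tensor.expand_shape %gen [[0, 1]]
//            : tensor<32xf32> into tensor<4x8xf32>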
1511 
1512 namespace {
1513 
1514 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1515 /// collapsing dimensions of the iteration space.
1516 class FoldWithProducerReshapeOpByCollapsing
1517  : public OpRewritePattern<GenericOp> {
1518 public:
1519  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1520  ControlFusionFn foldReshapes,
1521  PatternBenefit benefit = 1)
1522  : OpRewritePattern<GenericOp>(context, benefit),
1523  controlFoldingReshapes(std::move(foldReshapes)) {}
1524 
1525  LogicalResult matchAndRewrite(GenericOp genericOp,
1526  PatternRewriter &rewriter) const override {
1527  for (OpOperand &opOperand : genericOp->getOpOperands()) {
1528  tensor::ExpandShapeOp reshapeOp =
1529  opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1530  if (!reshapeOp)
1531  continue;
1532 
1533  SmallVector<ReassociationIndices> collapsableIterationDims =
1534  getCollapsableIterationSpaceDims(genericOp, &opOperand,
1535  reshapeOp.getReassociationIndices());
1536  if (collapsableIterationDims.empty() ||
1537  !controlFoldingReshapes(&opOperand)) {
1538  continue;
1539  }
1540 
1541  Optional<SmallVector<Value>> replacements =
1542  collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1543  rewriter);
1544  if (!replacements) {
1545  return rewriter.notifyMatchFailure(
1546  genericOp, "failed to do the fusion by collapsing transformation");
1547  }
1548 
1549  rewriter.replaceOp(genericOp, *replacements);
1550  return success();
1551  }
1552  return failure();
1553  }
1554 
1555 private:
1556  ControlFusionFn controlFoldingReshapes;
1557 };
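// Hypothetical trigger for this pattern: an expanding reshape feeding a
// generic op,
//
//   %e = tensor.expand_shape %arg0 [[0, 1]]
//          : tensor<32xf32> into tensor<4x8xf32>
//   %0 = linalg.generic ... ins(%e : tensor<4x8xf32>) ...
//
// is handled by collapsing the [0, 1] iteration dims of the consumer so that
// it consumes %arg0 directly, leaving the reshape to cancel or sink to the
// results.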
1558 
1559 /// Pattern to collapse dimensions of a linalg.generic op chosen by a callback.
1560 class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
1561 public:
1562  CollapseLinalgDimensions(MLIRContext *context,
1563  GetCollapsableDimensionsFn collapseDimensions,
1564  PatternBenefit benefit = 1)
1565  : OpRewritePattern<GenericOp>(context, benefit),
1566  controlCollapseDimension(std::move(collapseDimensions)) {}
1567 
1568  LogicalResult matchAndRewrite(GenericOp genericOp,
1569  PatternRewriter &rewriter) const override {
1570  SmallVector<ReassociationIndices> collapsableIterationDims =
1571  controlCollapseDimension(genericOp);
1572  if (collapsableIterationDims.empty())
1573  return failure();
1574 
1575  Optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(
1576  genericOp, collapsableIterationDims, rewriter);
1577  if (!replacements) {
1578  return rewriter.notifyMatchFailure(genericOp,
1579  "failed to collapse dimensions");
1580  }
1581  rewriter.replaceOp(genericOp, *replacements);
1582  return success();
1583  }
1584 
1585 private:
1586  GetCollapsableDimensionsFn controlCollapseDimension;
1587 };
1588 
1589 } // namespace
1590 
1591 //===---------------------------------------------------------------------===//
1592 // Methods and patterns that fuse constants with linalg.generic operations.
1593 //===---------------------------------------------------------------------===//
1594 
1595 namespace {
1596 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1597 /// handle cases where the constant is not single-valued.
1598 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1599 public:
1600  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1601  : OpRewritePattern<GenericOp>(context, benefit) {}
1602 
1603  LogicalResult matchAndRewrite(GenericOp genericOp,
1604  PatternRewriter &rewriter) const override {
1605  if (!genericOp.hasTensorSemantics())
1606  return failure();
1607  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1608  Operation *def = opOperand->get().getDefiningOp();
1609  TypedAttr constantAttr;
1610  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1611  {
1612  DenseElementsAttr splatAttr;
1613  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1614  splatAttr.isSplat() &&
1615  splatAttr.getType().getElementType().isIntOrFloat()) {
1616  constantAttr = splatAttr.getSplatValue<TypedAttr>();
1617  return true;
1618  }
1619  }
1620  {
1621  IntegerAttr intAttr;
1622  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1623  constantAttr = intAttr;
1624  return true;
1625  }
1626  }
1627  {
1628  FloatAttr floatAttr;
1629  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1630  constantAttr = floatAttr;
1631  return true;
1632  }
1633  }
1634  return false;
1635  };
1636 
1637  auto resultValue = opOperand->get().dyn_cast<OpResult>();
1638  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1639  continue;
1640 
1641  // The operands and the indexing_maps of the fused operation are the same as
1642  // the operands and indexing_maps of the generic operation, with the entries
1643  // at the constant operand's index dropped.
1644  SmallVector<AffineMap> fusedIndexMaps;
1645  SmallVector<Value> fusedOperands;
1646  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1647  fusedIndexMaps.reserve(genericOp->getNumOperands());
1648  fusedOperands.reserve(genericOp.getNumDpsInputs());
1649  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1650  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1651  if (inputOperand == opOperand)
1652  continue;
1653  Value inputValue = inputOperand->get();
1654  fusedIndexMaps.push_back(
1655  genericOp.getMatchingIndexingMap(inputOperand));
1656  fusedOperands.push_back(inputValue);
1657  fusedLocs.push_back(inputValue.getLoc());
1658  }
1659  for (OpOperand *outputOperand : genericOp.getDpsInitOperands())
1660  fusedIndexMaps.push_back(
1661  genericOp.getMatchingIndexingMap(outputOperand));
1662 
1663  // Check if the operation's shapes-to-loops map is computable.
1664  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1665  return rewriter.notifyMatchFailure(
1666  genericOp, "fused op loop bound computation failed");
1667  }
1668 
1669  // Create a constant scalar value from the splat constant.
1670  Value scalarConstant = rewriter.create<arith::ConstantOp>(
1671  def->getLoc(), constantAttr, constantAttr.getType());
1672 
1673  SmallVector<Value> outputOperands = genericOp.getOutputs();
1674  auto fusedOp = rewriter.create<GenericOp>(
1675  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1676  /*inputs=*/fusedOperands,
1677  /*outputs=*/outputOperands,
1678  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1679  genericOp.getIteratorTypes(),
1680  /*doc=*/nullptr,
1681  /*library_call=*/nullptr);
1682 
1683  // Map the block argument corresponding to the replaced argument to the
1684  // scalar constant.
1685  Region &region = genericOp->getRegion(0);
1686  Block &entryBlock = *region.begin();
1687  BlockAndValueMapping mapping;
1688  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1689  scalarConstant);
1690  Region &fusedRegion = fusedOp->getRegion(0);
1691  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1692  mapping);
1693  rewriter.replaceOp(genericOp, fusedOp->getResults());
1694  return success();
1695  }
1696  return failure();
1697  }
1698 };
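// Sketch of the rewrite on a hypothetical op: a splat-constant input
//
//   %cst = arith.constant dense<1.0> : tensor<4x8xf32>
//   %0 = linalg.generic ...
//          ins(%arg0, %cst : tensor<4x8xf32>, tensor<4x8xf32>) ...
//
// is dropped from the operand list, and the corresponding block argument is
// remapped to the scalar
//
//   %s = arith.constant 1.000000e+00 : f32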
1699 
1700 } // namespace
1701 
1702 //===---------------------------------------------------------------------===//
1703 // Miscellaneous patterns that help fusion.
1704 //===---------------------------------------------------------------------===//
1705 
1706 namespace {
1707 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
1708 /// value of the `outs` operand is not used within the op. This is only
1709 /// implemented for `linalg.generic` operations for now, but should hold for all
1710 /// linalg structured ops.
1711 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1712  using OpRewritePattern<GenericOp>::OpRewritePattern;
1713 
1714  LogicalResult matchAndRewrite(GenericOp op,
1715  PatternRewriter &rewriter) const override {
1716  rewriter.startRootUpdate(op);
1717  bool modifiedOutput = false;
1718  Location loc = op.getLoc();
1719  for (OpOperand *opOperand : op.getDpsInitOperands()) {
1720  if (!op.payloadUsesValueFromOperand(opOperand)) {
1721  Value operandVal = opOperand->get();
1722  auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1723  if (!operandType)
1724  continue;
1725 
1726  // If outs is sparse, leave it to the sparse compiler.
1727  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
1728  continue;
1729 
1730  // If outs is already an `empty` operation, nothing to do.
1731  auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
1732  if (definingOp)
1733  continue;
1734  modifiedOutput = true;
1735  SmallVector<Value> dynamicDims;
1736  for (const auto &dim : llvm::enumerate(operandType.getShape())) {
1737  if (dim.value() != ShapedType::kDynamic)
1738  continue;
1739  dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1740  loc, operandVal, dim.index()));
1741  }
1742  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1743  loc, operandType.getShape(), operandType.getElementType(),
1744  dynamicDims);
1745  op->setOperand(opOperand->getOperandNumber(), emptyTensor);
1746  }
1747  }
1748  if (!modifiedOutput) {
1749  rewriter.cancelRootUpdate(op);
1750  return failure();
1751  }
1752  rewriter.finalizeRootUpdate(op);
1753  return success();
1754  }
1755 };
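// Hypothetical before/after: when the payload never reads the outs value,
//
//   %0 = linalg.generic ... outs(%arg1 : tensor<?x8xf32>) ...
//
// becomes
//
//   %c0 = arith.constant 0 : index
//   %d0 = tensor.dim %arg1, %c0 : tensor<?x8xf32>
//   %e  = tensor.empty(%d0) : tensor<?x8xf32>
//   %0  = linalg.generic ... outs(%e : tensor<?x8xf32>) ...
//
// removing a false dependency on the producer of %arg1.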
1756 
1757 /// Fold linalg.fill into linalg.generic
1758 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
1759  using OpRewritePattern<GenericOp>::OpRewritePattern;
1760 
1761  LogicalResult matchAndRewrite(GenericOp genericOp,
1762  PatternRewriter &rewriter) const override {
1763  if (!genericOp.hasTensorSemantics())
1764  return failure();
1765  bool fillFound = false;
1766  Block &payload = genericOp.getRegion().front();
1767  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1768  if (!genericOp.payloadUsesValueFromOperand(opOperand))
1769  continue;
1770  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1771  if (!fillOp)
1772  continue;
1773  fillFound = true;
1774  Value fillVal = fillOp.value();
1775  auto resultType =
1776  fillOp.result().getType().cast<RankedTensorType>().getElementType();
1777  Value convertedVal =
1778  convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
1779  /*isUnsignedCast =*/false);
1780  payload.getArgument(opOperand->getOperandNumber())
1781  .replaceAllUsesWith(convertedVal);
1782  }
1783  return success(fillFound);
1784  }
1785 };
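// Hypothetical before/after: for an input produced by
//
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
//   %0 = linalg.generic ... ins(%f : tensor<4xf32>) ...
//
// the operand list is kept, but every payload use of the corresponding block
// argument is replaced by %cst (converted to the element type), typically
// leaving the fill and the operand dead for later cleanup patterns.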
1786 } // namespace
1787 
1788 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1789  RewritePatternSet &patterns,
1790  const ControlFusionFn &controlFoldingReshapes) {
1791  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1792  controlFoldingReshapes);
1793  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1794  controlFoldingReshapes);
1795 }
1796 
1797 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1798  RewritePatternSet &patterns,
1799  const ControlFusionFn &controlFoldingReshapes) {
1800  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
1801  controlFoldingReshapes);
1802 }
1803 
1804 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1805  RewritePatternSet &patterns,
1806  const ControlFusionFn &controlElementwiseOpsFusion) {
1807  auto *context = patterns.getContext();
1808  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1809  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1810  RemoveOutsDependency>(context);
1811  // Add the patterns that clean up dead operands and results.
1812  populateEraseUnusedOperandsAndResultsPatterns(patterns);
1813 }
1814 
1815 void mlir::linalg::populateCollapseDimensions(
1816  RewritePatternSet &patterns,
1817  const GetCollapsableDimensionsFn &controlCollapseDimensions) {
1818  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
1819  controlCollapseDimensions);
1820 }
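// A minimal usage sketch (hypothetical client code), populating these patterns
// with a control function that fuses only single-use producers, mirroring the
// default used by the test pass below:
//
//   RewritePatternSet patterns(context);
//   linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
//     Operation *producer = fusedOperand->get().getDefiningOp();
//     return producer && producer->hasOneUse();
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
//   linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));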
1821 
1822 //===---------------------------------------------------------------------===//
1823 // Passes
1824 //===---------------------------------------------------------------------===//
1825 
1826 namespace {
1827 
1828 /// Pass that fuses generic ops on tensors. Used only for testing.
1829 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
1830 // patterns added here heavily depends on the cost function used. Having an
1831 // opinionated pass of this form is not recommended. Deprecate this pass in
1832 // favor of test passes that check the functionality of each of the patterns
1833 // added here individually.
1834 struct LinalgElementwiseOpFusionPass
1835  : public impl::LinalgElementwiseOpFusionBase<
1836  LinalgElementwiseOpFusionPass> {
1837  void runOnOperation() override {
1838  Operation *op = getOperation();
1839  MLIRContext *context = op->getContext();
1840  RewritePatternSet patterns(context);
1841 
1842  // Add folding with reshape by expansion patterns.
1843  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
1844  Operation *producer = fusedOperand->get().getDefiningOp();
1845  return producer && producer->hasOneUse();
1846  };
1847 
1848  // Add elementwise op fusion patterns.
1849  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
1850  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
1851 
1852  // General canonicalization patterns.
1853  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1854  GenericOp::getCanonicalizationPatterns(patterns, context);
1855  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1856  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1857  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1858  patterns);
1859 
1860  // Add constant folding patterns.
1861  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
1862 
1863  // Use TopDownTraversal for compile-time reasons.
1864  GreedyRewriteConfig grc;
1865  grc.useTopDownTraversal = true;
1866  (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1867  grc);
1868  }
1869 };
1870 
1871 } // namespace
1872 
1873 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1874  return std::make_unique<LinalgElementwiseOpFusionPass>();
1875 }