ElementwiseOpFusion.cpp
1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 //===---------------------------------------------------------------------===//
32 // Methods and patterns that fuse elementwise `linalg.generic` operations.
33 //===---------------------------------------------------------------------===//
34 
35 /// Return the indexing map for an operand of the `producer` to use in the
36 /// fused operation, given the indexing map of the producer's result in the
37 /// consumer.
38 static AffineMap getIndexingMapOfProducerOperandsInCoordsOfFusedOp(
39  OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
40  AffineMap fusedConsumerArgIndexMap) {
41  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
42  // from consumer loop -> consumer arg tensor index/producer result tensor
43  // index. The fused loop is the same as the consumer loop. For each producer arg
44  // the indexing map to be computed is a map from consumer loop -> producer
45  // arg tensor index.
46  // producerResultIndexMap is a map from producer loop -> tensor index.
47  // Compute the inverse to get map from tensor index -> producer loop.
48  // The inverse is a map from producer result tensor index -> producer loop.
49  AffineMap invProducerResultIndexMap =
50  inversePermutation(producerResultIndexMap);
51  assert(invProducerResultIndexMap &&
52  "expected producer result indexing map to be invertible");
53 
54  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
55  // argMap is a map from producer loop -> producer arg tensor index.
56  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
57 
58  // Compose argMap with invProducerResultIndexMap to get a map from
59  // producer result tensor index -> producer arg tensor index.
60  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
61 
62  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
63  // consumer loop/ fused loop -> producer arg tensor index.
64  return t1.compose(fusedConsumerArgIndexMap);
65 }
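// A worked example of the composition above (hypothetical maps, not taken
// from this file): if the producer result map is (d0, d1) -> (d1, d0), a
// producer argument map is (d0, d1) -> (d0), and the consumer accesses the
// producer result with (i, j) -> (j, i), then the inverse of the result map
// is (t0, t1) -> (t1, t0), composing the argument map with it gives
// (t0, t1) -> (t1), and composing that with the consumer map gives
// (i, j) -> (i), i.e. the map the producer argument uses in the fused op.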
66 
67 /// Conditions for elementwise fusion of generic operations.
68 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
69  OpOperand *consumerOpOperand) {
70  // Producer and consumer must have tensor semantics.
71  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
72  return false;
73 
74  // Verify that
75  // - the producer has all "parallel" iterator types.
76  if (producer.getNumParallelLoops() != producer.getNumLoops())
77  return false;
78 
79  // Only allow fusing the producer of an input operand for now.
80  // TODO: allow fusing the producer of an output operand.
81  if (!consumer.isInputTensor(consumerOpOperand))
82  return false;
83 
84  // Get the consumer index map. The number of results of the consumer index
85  // map must match the number of loops of the producer.
86  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
88  return false;
89 
90  // Currently support only operations with single result.
91  if (producer.getNumOutputs() != 1)
92  return false;
93 
94  // Finally the index_map for the result must be invertible. For now just
95  // verify it is a permutation.
96  AffineMap producerResultIndexMap =
97  producer.getTiedIndexingMap(producer.getOutputOperand(0));
98  if (!producerResultIndexMap.isPermutation())
99  return false;
100 
101  // Ensure that the fusion does not remove size information required to
102  // get the loop bounds. For non-reduction generics, this is trivially the
103  // case due to the output operand. For reductions, we need to check that after
104  // the fusion, each loop dimension has at least one input that defines it.
105  if ((consumer.getNumReductionLoops())) {
106  BitVector coveredDims(consumer.getNumLoops(), false);
107 
108  auto addToCoveredDims = [&](AffineMap map) {
109  for (auto result : map.getResults())
110  if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
111  coveredDims[dimExpr.getPosition()] = true;
112  };
113 
114  for (auto pair :
115  llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
116  Value operand = std::get<0>(pair);
117  if (operand == consumerOpOperand->get())
118  continue;
119  AffineMap operandMap = std::get<1>(pair);
120  addToCoveredDims(operandMap);
121  }
122 
123  for (OpOperand *operand : producer.getInputOperands()) {
124  AffineMap newIndexingMap =
125  getIndexingMapOfProducerOperandsInCoordsOfFusedOp(
126  operand, producerResultIndexMap, consumerIndexMap);
127  addToCoveredDims(newIndexingMap);
128  }
129  if (!coveredDims.all())
130  return false;
131  }
132 
133  return true;
134 }
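// Illustrative IR satisfying the conditions above (a sketch, not from a test
// file): two all-parallel element-wise generics on tensors, where the first
// feeds an input operand of the second through an invertible (here identity)
// result indexing map, are candidates for fusion into one generic op:
//
// ```mlir
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %0 = linalg.generic {indexing_maps = [#map, #map],
//                        iterator_types = ["parallel", "parallel"]}
//        ins(%a : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#map, #map],
//                        iterator_types = ["parallel", "parallel"]}
//        ins(%0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {...}
// ```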
135 
136 /// Generate the region of the fused tensor operation. The region of the fused
137 /// op must be empty.
138 static void
139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
140  AffineMap consumerToProducerLoopsMap,
141  OpOperand *consumerOpOperand,
142  unsigned nloops) {
143  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
144  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
145  // Build the region of the fused op.
146  Block &producerBlock = producer->getRegion(0).front();
147  Block &consumerBlock = consumer->getRegion(0).front();
148  Block *fusedBlock = new Block();
149  fusedOp.getRegion().push_back(fusedBlock);
150  BlockAndValueMapping mapper;
151  OpBuilder::InsertionGuard guard(rewriter);
152  rewriter.setInsertionPointToStart(fusedBlock);
153 
154  // 2. Add an index operation for every fused loop dimension and use the
155  // `consumerToProducerLoopsMap` to map the producer indices.
156  if (producer.hasIndexSemantics()) {
157  // Add an index operation for every fused loop dimension.
158  unsigned numFusedOpLoops =
159  std::max(producer.getNumLoops(), consumer.getNumLoops());
160  SmallVector<Value> fusedIndices;
161  fusedIndices.reserve(numFusedOpLoops);
162  llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163  std::back_inserter(fusedIndices), [&](uint64_t dim) {
164  return rewriter.create<IndexOp>(producer.getLoc(), dim);
165  });
166  for (IndexOp indexOp :
167  llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168  Value newIndex = rewriter.create<mlir::AffineApplyOp>(
169  producer.getLoc(),
170  consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
171  mapper.map(indexOp.getResult(), newIndex);
172  }
173  }
174  // TODO: allow fusing the producer of an output operand.
175  assert(consumer.isInputTensor(consumerOpOperand) &&
176  "expected producer of input operand");
177  // 3. Consumer input operands up to consumerIdx (exclusive).
178  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
179  consumerOpOperand->getOperandNumber())) // input assumption.
180  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
181 
182  // Replacing consumerIdx requires getting the cloned, yielded value from
183  // the (cloned) producer block. This happens in step 9.
184 
185  // 4. Splice in producer's input operands.
186  for (BlockArgument bbArg :
187  producerBlock.getArguments().take_front(producer.getNumInputs()))
188  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
189 
190  // 4.b. Producer output operand/map that is fused needs to be mapped to the
191  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
192  assert(producer->getNumResults() == 1 && "expected single result producer");
193  if (producer.isInitTensor(producer.getOutputOperand(0))) {
194  BlockArgument bbArg = producerBlock.getArguments()
195  .drop_front(producer.getNumInputs())
196  // TODO: bbArg index of
197  .front();
198  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
199  }
200  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
201  for (BlockArgument bbArg :
202  consumerBlock.getArguments()
203  .take_front(consumer.getNumInputs())
204  .drop_front(consumerOpOperand->getOperandNumber() + 1))
205  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
206  // 6. All of consumer's output operands.
207  for (BlockArgument bbArg :
208  consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
209  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
210  // 7. All of producer's output operands except the one fused.
211  // TODO: allow fusion of multi-result producers.
212  assert(producer->getNumResults() == 1 && "expected single result producer");
213 
214  // 8. Clone all producer operations except for the yield and index operations
215  // to the fused operation.
216  for (auto &op : producerBlock.without_terminator()) {
217  if (!isa<IndexOp>(op))
218  rewriter.clone(op, mapper);
219  }
220  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
221  // forward the yield operand.
222  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
223  // TODO: allow fusion of multi-result producers.
224  assert(producer->getNumResults() == 1 && "expected single result producer");
225  unsigned producerResultNumber = 0;
226  Value replacement =
227  mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
228  // Sanity checks, if replacement is not already in the mapper then it must be
229  // produced outside.
230  if (replacement == yieldOp.getOperand(producerResultNumber)) {
231  if (auto bb = replacement.dyn_cast<BlockArgument>())
232  assert(bb.getOwner() != &producerBlock &&
233  "yielded block argument must have been mapped");
234  else
235  assert(!producer->isAncestor(replacement.getDefiningOp()) &&
236  "yielded value must have been mapped");
237  }
238  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
239  replacement);
240  // 10. Clone operations from the consumer to the fused op.
241  for (auto &op : consumerBlock.getOperations())
242  rewriter.clone(op, mapper);
243 
244  // Sanity checks.
245  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
246  "Ill-formed GenericOp region");
247 }
248 
249 static Optional<SmallVector<Value>>
250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
251  const ControlFusionFn &controlFn,
252  PatternRewriter &rewriter) {
253  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
254  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
255  !controlFn(producer->getResult(0), *consumerOpOperand))
256  return llvm::None;
257 
258  // TODO: allow fusing the producer of an output operand.
259  assert(consumer.isInputTensor(consumerOpOperand) &&
260  "expected producer of input operand");
261 
262  // Compute the fused operands list and indexing maps.
263  SmallVector<Value> fusedOperands;
264  SmallVector<AffineMap> fusedIndexMaps;
265  fusedOperands.reserve(producer->getNumOperands() +
266  consumer->getNumOperands());
267  fusedIndexMaps.reserve(producer->getNumOperands() +
268  consumer->getNumOperands());
269  // In the following, numbering matches that of `generateFusedTensorOpRegion`.
270  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
271  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
272  SmallVector<OpOperand *>::iterator it =
273  llvm::find(consumerInputs, consumerOpOperand);
274  assert(it != consumerInputs.end() && "expected to find the consumer operand");
275  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276  fusedOperands.push_back(opOperand->get());
277  fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
278  }
279  // 4. Splice in producer's input operands/maps.
280  assert(producer->getNumResults() == 1 && "expected single result producer");
281  AffineMap producerResultIndexMap =
282  producer.getTiedIndexingMap(producer.getOutputOperand(0));
283  for (OpOperand *opOperand : producer.getInputOperands()) {
284  fusedOperands.push_back(opOperand->get());
285  // Compute indexing maps for the producer args in the fused operation.
286  AffineMap map = getIndexingMapOfProducerOperandsInCoordsOfFusedOp(
287  opOperand, producerResultIndexMap,
288  consumer.getTiedIndexingMap(consumerOpOperand));
289  fusedIndexMaps.push_back(map);
290  }
291  // 4.b. Producer output operand/map that is fused needs to be passed if it is
292  // an "initTensor" (i.e. its value is actually read).
293  assert(producer->getNumResults() == 1 && "expected single result producer");
294  if (producer.isInitTensor(producer.getOutputOperand(0))) {
295  fusedOperands.push_back(producer.getOutputOperand(0)->get());
296  // Compute indexing maps for the producer args in the fused operation.
297  AffineMap map = getIndexingMapOfProducerOperandsInCoordsOfFusedOp(
298  producer.getOutputOperand(0), producerResultIndexMap,
299  consumer.getTiedIndexingMap(consumerOpOperand));
300  fusedIndexMaps.push_back(map);
301  }
302  // 5. Remaining consumer's input operands/maps (drop past index
303  // `consumerIdx`).
304  for (OpOperand *opOperand :
305  llvm::make_range(std::next(it), consumerInputs.end())) {
306  fusedOperands.push_back(opOperand->get());
307  fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
308  }
309  // 6. All of consumer's output operands (skip operands: added by the builder).
310  for (OpOperand *opOperand : consumer.getOutputOperands())
311  fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
312  // 7. All of producer's output operands/maps except the one fused.
313  // TODO: allow fusion of multi-result producers.
314  assert(producer->getNumResults() == 1 && "expected single result producer");
315 
316  // Generate the fused op.
317  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
318  auto fusedOp = rewriter.create<GenericOp>(
319  consumer.getLoc(), consumer->getResultTypes(),
320  /*inputs=*/fusedOperands,
321  // TODO: handle outputs.
322  consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
323  consumer.getIteratorTypes(),
324  /*doc=*/nullptr,
325  /*library_call=*/nullptr);
326  if (!fusedOp.getShapesToLoopsMap()) {
327  // Fused op has invalid indexing maps. Typically this means something is off
328  // in the input, but going ahead here would result in verification errors.
329  // So clean up and abort.
330  rewriter.eraseOp(fusedOp);
331  return llvm::None;
332  }
333 
334  // Construct an AffineMap from consumer loops to producer loops.
335  // consumer loop -> tensor index
336  AffineMap consumerResultIndexMap =
337  consumer.getTiedIndexingMap(consumerOpOperand);
338  // tensor index -> producer loop
339  AffineMap invProducerResultIndexMap =
340  inversePermutation(producerResultIndexMap);
341  assert(invProducerResultIndexMap &&
342  "expected producer result indexig map to be invertible");
343  // consumer loop -> producer loop
344  AffineMap consumerToProducerLoopsMap =
345  invProducerResultIndexMap.compose(consumerResultIndexMap);
346 
347  generateFusedElementwiseOpRegion(rewriter, fusedOp,
348  consumerToProducerLoopsMap,
349  consumerOpOperand, consumer.getNumLoops());
350  return SmallVector<Value>(fusedOp->getResults());
351 }
352 
353 static Optional<SmallVector<Value>>
354 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
355  GenericOp producer, const ControlFusionFn &controlFn) {
356  if (producer->getNumResults() != 1)
357  return llvm::None;
358 
359  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360  rewriter);
361 }
362 
363 namespace {
364 /// Pattern to fuse a generic op with the producer of one of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
367  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
368  PatternBenefit benefit = 1)
369  : OpRewritePattern<GenericOp>(context, benefit),
370  controlFn(std::move(fun)) {}
371 
372  LogicalResult matchAndRewrite(GenericOp genericOp,
373  PatternRewriter &rewriter) const override {
374  // Find the first operand that is defined by another generic op on tensors.
375  for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
376  auto producer =
377  dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378  if (!producer || !producer.hasTensorSemantics())
379  continue;
380  Optional<SmallVector<Value>> fusedOpResults =
381  fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
382  if (fusedOpResults) {
383  rewriter.replaceOp(genericOp, *fusedOpResults);
384  return success();
385  }
386  }
387  return failure();
388  }
389 
390 private:
391  ControlFusionFn controlFn;
392 };
393 } // namespace
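// Usage sketch (assumes the public `populateElementwiseOpsFusionPatterns`
// entry point declared in the Linalg Transforms header; not part of this
// file):
//
//   RewritePatternSet patterns(context);
//   ControlFusionFn fuseAll = [](const OpResult &producer,
//                                OpOperand &consumer) { return true; };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, fuseAll);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));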
394 
395 //===---------------------------------------------------------------------===//
396 // Methods and patterns that fuse reshape ops with elementwise operations by
397 // expanding the dimensionality of the elementwise operations.
398 //===---------------------------------------------------------------------===//
399 
400 /// Conditions for folding a generic operation with a reshape op by expanding
401 /// the iteration space dimensionality for tensor operations. These are
402 /// preconditions assumed by `fuseWithReshapeByExpansion` which implements the
403 /// following fusion pattern.
404 ///
405 /// Consider
406 ///
407 /// %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
408 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
409 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
410 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
411 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
412 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
413 ///
414 /// The reshape can be folded into the `genericOp` if its loop dimensionality
415 /// is increased to match the result (operand) of the tensor.expand_shape.
416 /// The indexing_map of the fused tensor in the `genericOp` and the
417 /// reassociation map help compute the indexing maps of the modified op.
418 /// For the above example, based on the reassociation map it
419 /// can be concluded that
420 ///
421 /// - The loop used to access the first dimension of the fused tensor is split
422 /// into two.
423 /// - The loop used to access the second dimension of the fused tensor is kept
424 /// as is.
425 /// - The loop used to access the third dimension of the fused tensor is split
426 /// into three.
427 ///
428 /// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
429 /// modified op, then
430 ///
431 /// d0 -> e0, e1
432 /// d1 -> e2, e3, e4
433 /// d2 -> e5
434 ///
435 /// substituting this, the generic op can be rewritten as
436 ///
437 /// %d = linalg.generic ins(%0, %1 : )
438 /// indexing_maps =
439 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
440 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
441 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
442 ///
443 /// Since operands to the linalg generic are now 5D, reshapes can be introduced
444 /// to make it consistent
445 ///
446 /// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
447 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
448 /// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
449 /// : tensor<?x?xf32> into tensor<?x?x?x?xf32>
450 ///
451 /// The added reshapes are again expanding patterns, so they will get fused
452 /// with their producers if possible.
453 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
454  OpOperand *fusableOpOperand) {
455  // Is fusable only if:
456  // - All the indexing maps for operands and results are projected
457  // permutations.
458  // - The fused tensor is not a scalar.
459  // - All the loops are parallel loops.
460  return genericOp.hasTensorSemantics() &&
461  llvm::all_of(genericOp.getIndexingMaps().getValue(),
462  [](Attribute attr) {
463  return attr.cast<AffineMapAttr>()
464  .getValue()
465  .isProjectedPermutation();
466  }) &&
467  genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
468  llvm::all_of(genericOp.getIteratorTypes(), [](Attribute attr) {
469  return attr.cast<StringAttr>().getValue() ==
470  getParallelIteratorTypeName();
471  });
472 }
473 
474 namespace {
475 /// Information needed to expand a generic operation to fold the reshape with
476 /// it.
477 class ExpansionInfo {
478 public:
479  // Computes the mapping from original dimensions of the op to the dimensions
480  // of the expanded op given the `indexingMap` of the fused operand/result of
481  // the generic op, the `reassociationMaps` of the reshape op and the shape of
482  // the expanded op.
483  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
484  ArrayRef<AffineMap> reassociationMaps,
485  ArrayRef<int64_t> expandedShape,
486  ArrayRef<int64_t> collapsedShape,
487  PatternRewriter &rewriter);
488  unsigned getOrigOpNumDims() const { return reassociation.size(); }
489  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
490  ReassociationIndicesRef getExpandedDims(unsigned i) const {
491  return reassociation[i];
492  }
493  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
494  return expandedShapeMap[i];
495  }
496  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
497 
498 private:
499  /// Reassociation from the dimensions in the original operation to the
500  /// dimension of the expanded operation.
501  SmallVector<ReassociationIndices> reassociation;
502  /// Mapping from extent of loops in the original operation, to the extent of
503  /// loops in the expanded operation.
504  SmallVector<SmallVector<int64_t>> expandedShapeMap;
505  /// Extent of the loop in the original operation.
506  SmallVector<int64_t> originalLoopExtent;
507  unsigned expandedOpNumDims;
508 };
509 } // namespace
510 
511 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
512  OpOperand *fusableOpOperand,
513  ArrayRef<AffineMap> reassociationMaps,
514  ArrayRef<int64_t> expandedShape,
515  ArrayRef<int64_t> collapsedShape,
516  PatternRewriter &rewriter) {
517  if (reassociationMaps.empty())
518  return failure();
519  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
520 
521  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
522  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
523 
524  reassociation.clear();
525  expandedShapeMap.clear();
526  // Compute the number of dimensions in the expanded op that correspond to each
527  // dimension of the original op.
528  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
529  expandedShapeMap.resize(fusedIndexMap.getNumDims());
530  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
531  unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
532  AffineMap foldedDims = reassociationMaps[resultExpr.index()];
533  numExpandedDims[pos] = foldedDims.getNumResults();
534  ArrayRef<int64_t> shape =
535  expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
536  expandedShapeMap[pos].assign(shape.begin(), shape.end());
537  }
538  // The remaining dimensions remain the same.
539  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
540  if (expandedShapeMap[i].empty())
541  expandedShapeMap[i] = {originalLoopExtent[i]};
542 
543  // Compute reassociation map from the original op to the expanded op.
544  unsigned sum = 0;
545  reassociation.reserve(fusedIndexMap.getNumDims());
546  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
547  auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
548  reassociation.emplace_back(seq.begin(), seq.end());
549  sum += numFoldedDim.value();
550  }
551  expandedOpNumDims = sum;
552  return success();
553 }
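// Worked example of `compute` (hypothetical shapes): for a fused operand with
// indexing map (d0, d1) -> (d1, d0), reassociation [[0, 1], [2]] and expanded
// shape [4, 8, ?], result dim 0 (i.e. d1) is split into two dims of sizes
// [4, 8] while result dim 1 (i.e. d0) stays a single dim of size [?]. This
// yields numExpandedDims = [1, 2], expandedShapeMap = {d0: [?], d1: [4, 8]},
// reassociation = [[0], [1, 2]] and an expanded op with 3 loops.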
554 
555 /// Expanding the body of a linalg operation requires adaptations of the
556 /// accessed loop indices. Specifically, accesses of indices in the original
557 /// operation need to be replaced with linearizations of indices in the expanded
558 /// op. That requires the shape of the expanded dimensions to be static (at
559 /// least all but the most significant). For now check that these are all
560 /// statically sized. Note that this could be extended to handle the dynamic
561 /// case, but the implementation below uses `affine.apply` which seems to have
562 /// issues when the shapes are not static.
563 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
564  const ExpansionInfo &expansionInfo,
565  PatternRewriter &rewriter) {
566  if (!genericOp.hasIndexSemantics())
567  return success();
568  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
569  ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
570  if (expandedShape.size() == 1)
571  continue;
572  for (int64_t shape : expandedShape.drop_front()) {
573  if (ShapedType::isDynamic(shape)) {
574  return rewriter.notifyMatchFailure(
575  genericOp, "cannot expand due to index semantics and dynamic dims");
576  }
577  }
578  }
579  return success();
580 }
581 
582 /// Return the indexing map to use in the expanded op given the
583 /// `indexingMap` of the original operation.
584 static AffineMap
585 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
586  const ExpansionInfo &expansionInfo) {
587  SmallVector<AffineExpr> newExprs;
588  for (AffineExpr expr : indexingMap.getResults()) {
589  unsigned pos = expr.cast<AffineDimExpr>().getPosition();
590  SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
591  llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
592  return builder.getAffineDimExpr(static_cast<unsigned>(v));
593  }));
594  newExprs.append(expandedExprs.begin(), expandedExprs.end());
595  }
596  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
597  indexingMap.getNumSymbols(), newExprs,
598  builder.getContext());
599 }
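// Worked example (hypothetical): with an original indexing map
// (d0, d1) -> (d1, d0) and an expansion where d0 stays a single dim (e0) and
// d1 becomes two dims (e1, e2), the map in the expanded op becomes
// (e0, e1, e2) -> (e1, e2, e0).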
600 
601 /// Return the type of the operand/result to use in the expanded op given the
602 /// type in the original op.
603 static RankedTensorType getExpandedType(RankedTensorType originalType,
604  AffineMap indexingMap,
605  const ExpansionInfo &expansionInfo) {
606  SmallVector<int64_t> expandedShape;
607  for (AffineExpr expr : indexingMap.getResults()) {
608  unsigned dim = expr.cast<AffineDimExpr>().getPosition();
609  auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
610  expandedShape.append(dimExpansion.begin(), dimExpansion.end());
611  }
612  return RankedTensorType::get(expandedShape, originalType.getElementType());
613 }
614 
615 /// Returns the reassociation maps to use in the `tensor.expand_shape`
616 /// operation to convert the operands of the original operation to operands of
617 /// the expanded operation. The same method is used to compute the
618 /// `tensor.collapse_shape` used to collapse the result of the expanded
619 /// op to get the value that can replace all uses of the results of the original
620 /// op.
621 static SmallVector<ReassociationIndices>
622 getReassociationForExpansion(AffineMap indexingMap,
623  const ExpansionInfo &expansionInfo) {
624  SmallVector<ReassociationIndices> reassociation;
625  unsigned numReshapeDims = 0;
626  for (AffineExpr expr : indexingMap.getResults()) {
627  unsigned dim = expr.cast<AffineDimExpr>().getPosition();
628  auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
629  SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
630  llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
631  reassociation.emplace_back(std::move(indices));
632  numReshapeDims += numExpandedDims;
633  }
634  return reassociation;
635 }
636 
637 /// Update the body of an expanded linalg operation having index semantics. The
638 /// indices of the original operation need to be recovered by linearizing the
639 /// indices of the corresponding dimensions of the expanded operation. For now it
640 /// is assumed that the shapes of the expanded operation needed for
641 /// linearization are static.
642 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
643  Location loc, Region &fusedRegion,
644  const ExpansionInfo &expansionInfo) {
645  // Replace the original indices by the linearization of the expanded indices.
646  for (IndexOp indexOp :
647  llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
648  ArrayRef<int64_t> expandedDims =
649  expansionInfo.getExpandedDims(indexOp.getDim());
650  assert(!expandedDims.empty() && "expected valid expansion info");
651 
652  // Skip index operations that are not affected by the expansion.
653  if (expandedDims.size() == 1 &&
654  expandedDims.front() == (int64_t)indexOp.getDim())
655  continue;
656 
657  // Linearize the expanded indices of the original index dimension.
658  OpBuilder::InsertionGuard guard(rewriter);
659  rewriter.setInsertionPointAfter(indexOp);
660  ArrayRef<int64_t> expandedDimsShape =
661  expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
662  SmallVector<Value> expandedIndices;
663  expandedIndices.reserve(expandedDims.size() - 1);
664  llvm::transform(
665  expandedDims.drop_front(), std::back_inserter(expandedIndices),
666  [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
667  Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
668  for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
669  assert(!ShapedType::isDynamic(std::get<0>(it)));
670  AffineExpr idx, acc;
671  bindDims(rewriter.getContext(), idx, acc);
672  newIndex = rewriter.create<AffineApplyOp>(
673  indexOp.getLoc(), idx + acc * std::get<0>(it),
674  ValueRange{std::get<1>(it), newIndex});
675  }
676  rewriter.replaceOp(indexOp, newIndex);
677  }
678 }
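// Worked example (hypothetical): if loop dimension 1 of the original op is
// expanded into dims (1, 2) with a static inner extent of 4, the original
// `linalg.index 1` is rebuilt as `%i1 * 4 + %i2`, where %i1 and %i2 are
// `linalg.index 1` and `linalg.index 2` of the expanded op; the affine.apply
// created above performs exactly this linearization.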
679 
680 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
681 /// and a generic op as explained in `isFusableWithReshapeByDimExpansion`. Assumes
682 /// that those conditions have been satisfied.
683 static Optional<SmallVector<Value>>
684 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
685  OpOperand *fusableOpOperand,
686  PatternRewriter &rewriter) {
687  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
688  "preconditions for fuse operation failed");
689  // Check if reshape is expanding or collapsing.
690  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
691  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
692  bool isExpanding = (expandingReshapeOp != nullptr);
693  RankedTensorType expandedType = isExpanding
694  ? expandingReshapeOp.getResultType()
695  : collapsingReshapeOp.getSrcType();
696  RankedTensorType collapsedType = isExpanding
697  ? expandingReshapeOp.getSrcType()
698  : collapsingReshapeOp.getResultType();
699 
700  ExpansionInfo expansionInfo;
701  if (failed(expansionInfo.compute(
702  genericOp, fusableOpOperand,
703  isExpanding ? expandingReshapeOp.getReassociationMaps()
704  : collapsingReshapeOp.getReassociationMaps(),
705  expandedType.getShape(), collapsedType.getShape(), rewriter)))
706  return llvm::None;
707 
708  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
709  return llvm::None;
710 
711  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
712  llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
713  return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
714  }));
715 
716  SmallVector<Value> expandedOpOperands;
717  expandedOpOperands.reserve(genericOp.getNumInputs());
718  for (OpOperand *opOperand : genericOp.getInputOperands()) {
719  if (opOperand == fusableOpOperand) {
720  expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
721  : collapsingReshapeOp.getSrc());
722  continue;
723  }
724  if (genericOp.isInputTensor(opOperand)) {
725  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
726  auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
727  RankedTensorType expandedOperandType =
728  getExpandedType(opOperandType, indexingMap, expansionInfo);
729  if (expandedOperandType != opOperand->get().getType()) {
730  // Reshape the operand to get the right type.
731  SmallVector<ReassociationIndices> reassociation =
732  getReassociationForExpansion(indexingMap, expansionInfo);
733  if (failed(reshapeLikeShapesAreCompatible(
734  [&](const Twine &msg) {
735  return rewriter.notifyMatchFailure(genericOp, msg);
736  },
737  opOperandType.getShape(), expandedOperandType.getShape(),
738  reassociation,
739  /*isExpandingReshape=*/true)))
740  return llvm::None;
741  expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
742  genericOp.getLoc(), expandedOperandType, opOperand->get(),
743  reassociation));
744  continue;
745  }
746  }
747  expandedOpOperands.push_back(opOperand->get());
748  }
749 
750  Location loc = genericOp.getLoc();
751  SmallVector<Value> outputs;
752  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
753  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
754  auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
755  RankedTensorType expandedOutputType =
756  getExpandedType(opOperandType, indexingMap, expansionInfo);
757  if (expandedOutputType != opOperand->get().getType()) {
758  SmallVector<ReassociationIndices> reassociation =
759  getReassociationForExpansion(indexingMap, expansionInfo);
760  if (failed(reshapeLikeShapesAreCompatible(
761  [&](const Twine &msg) {
762  return rewriter.notifyMatchFailure(genericOp, msg);
763  },
764  opOperandType.getShape(), expandedOutputType.getShape(),
765  reassociation,
766  /*isExpandingReshape=*/true)))
767  return llvm::None;
768  outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
769  genericOp.getLoc(), expandedOutputType, opOperand->get(),
770  reassociation));
771  }
772  }
773 
774  // The iterator types of the expanded op are all parallel.
775  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
776  getParallelIteratorTypeName());
777 
778  TypeRange resultTypes = ValueRange(outputs).getTypes();
779  auto fusedOp =
780  rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
781  /*inputs=*/expandedOpOperands, outputs,
782  expandedOpIndexingMaps, iteratorTypes);
783  Region &fusedRegion = fusedOp->getRegion(0);
784  Region &originalRegion = genericOp->getRegion(0);
785  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
786 
787  // Update the index accesses after the expansion.
788  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
789 
790  // Reshape the result values to their original shape if this is a collapsing
791  // reshape folded into its consumer.
792  SmallVector<Value> resultVals;
793  for (OpResult opResult : genericOp->getOpResults()) {
794  int64_t resultNumber = opResult.getResultNumber();
795  if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
796  SmallVector<ReassociationIndices> reassociation =
797  getReassociationForExpansion(
798  genericOp.getTiedIndexingMap(
799  genericOp.getOutputOperand(resultNumber)),
800  expansionInfo);
801  resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
802  genericOp.getLoc(), opResult.getType(),
803  fusedOp->getResult(resultNumber), reassociation));
804  } else {
805  resultVals.push_back(fusedOp->getResult(resultNumber));
806  }
807  }
808  // Assuming a single result.
809  return resultVals;
810 }
811 
812 namespace {
813 
814 /// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
815 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
816 /// in the consumer is expanded.
817 class FoldWithProducerReshapeOpByExpansion
818  : public OpRewritePattern<GenericOp> {
819 public:
820  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
821  ControlFusionFn foldReshapes,
822  PatternBenefit benefit = 1)
823  : OpRewritePattern<GenericOp>(context, benefit),
824  controlFoldingReshapes(std::move(foldReshapes)) {}
825 
826  LogicalResult matchAndRewrite(GenericOp genericOp,
827  PatternRewriter &rewriter) const override {
828  for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
829  tensor::CollapseShapeOp reshapeOp =
830  opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
831  if (!reshapeOp)
832  continue;
833  // Fold only if
834  // - The tensor reshape op is folding.
835  // - All constraints of fusing with reshape by expansion are met.
836  if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
837  (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
838  continue;
839 
840  Optional<SmallVector<Value>> replacementValues =
841  fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
842  if (!replacementValues)
843  return failure();
844  rewriter.replaceOp(genericOp, *replacementValues);
845  return success();
846  }
847  return failure();
848  }
849 
850 private:
851  ControlFusionFn controlFoldingReshapes;
852 };
853 
854 /// Pattern to fold a tensor.expand_shape op with its producer generic op
855 /// by expanding the dimensionality of the loop in the producer op.
856 struct FoldReshapeWithGenericOpByExpansion
857  : public OpRewritePattern<tensor::ExpandShapeOp> {
858 
859  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
860  ControlFusionFn foldReshapes,
861  PatternBenefit benefit = 1)
862  : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
863  controlFoldingReshapes(std::move(foldReshapes)) {}
864 
865  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
866  PatternRewriter &rewriter) const override {
867  // Fold only if all constraints of fusing with reshape by expansion are met.
868  GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
869  if (!producer || producer.getNumOutputs() != 1 ||
870  !isFusableWithReshapeByDimExpansion(producer,
871  producer.getOutputOperand(0)) ||
872  !controlFoldingReshapes(producer->getResult(0),
873  reshapeOp->getOpOperand(0)))
874  return failure();
875  Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
876  producer, reshapeOp, producer.getOutputOperand(0), rewriter);
877  if (!replacementValues)
878  return failure();
879  rewriter.replaceOp(reshapeOp, *replacementValues);
880  return success();
881  }
882 
883 private:
884  ControlFusionFn controlFoldingReshapes;
885 };
886 } // namespace
887 
888 //===---------------------------------------------------------------------===//
889 // Methods and patterns to fuse reshape with linalg.generic operations by
890 // contraction of dimensions.
891 //===---------------------------------------------------------------------===//
892 
893 /// For a given list of indices in the range of the `indexingMap` that are
894 /// folded, return the indices of the corresponding domain. Ensures that all
895 /// the elements of the returned reassociation are distinct. (The indexing map
896 /// is required to be a projected permutation.)
897 static ReassociationIndices
898 getDomainReassociation(AffineMap indexingMap,
899  ReassociationIndicesRef rangeReassociation) {
900  assert(indexingMap.isProjectedPermutation() &&
901  "expected projected permutation");
902 
903  ReassociationIndices domainReassociation = llvm::to_vector<4>(
904  llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
905  return indexingMap.getResults()[pos]
906  .cast<AffineDimExpr>()
907  .getPosition();
908  }));
909  // The projected permutation semantics ensures that there is no repetition of
910  // the domain indices.
911  return domainReassociation;
912 }
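// Worked example (hypothetical): for an indexing map (d0, d1, d2) -> (d2, d0)
// and a range reassociation [0, 1] (the two results being folded), the
// returned domain reassociation is [2, 0], i.e. the loop dimensions those
// results refer to.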
913 
914 /// For a given `dimSequence`, check if the sequence is conserved in the
915 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
916 /// Non-existence of the sequence returns true as well.
917 static bool isDimSequencePreserved(AffineMap indexingMap,
918  ReassociationIndicesRef dimSequence) {
919  assert(!dimSequence.empty() &&
920  "expected non-empty list for dimension sequence");
921  assert(indexingMap.isProjectedPermutation() &&
922  "expected indexing map to be projected permutation");
923 
924  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
925  sequenceElements.insert(dimSequence.begin(), dimSequence.end());
926 
927  unsigned dimSequenceStart = dimSequence[0];
928  for (const auto &expr : enumerate(indexingMap.getResults())) {
929  unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
930  // 1. Check if this is the start of the sequence.
931  if (dimInMapStart == dimSequenceStart) {
932  if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
933  return false;
934  // 1a. Check if sequence is preserved.
935  for (const auto &dimInSequence : enumerate(dimSequence)) {
936  unsigned dimInMap =
937  indexingMap.getResult(expr.index() + dimInSequence.index())
938  .cast<AffineDimExpr>()
939  .getPosition();
940  if (dimInMap != dimInSequence.value())
941  return false;
942  }
943  // Found the sequence. Projected permutation
944  // enforces that all AffineDimExprs in the result are unique, so no
945  // further checks are needed.
946  return true;
947  }
948  // 2. If position in the expr (which is of type AffineDimExpr) is part
949  // of sequence, return false here. This implies the entire sequence does not
950  // exist in the indexing map.
951  if (sequenceElements.count(dimInMapStart))
952  return false;
953  }
954  // 3. No element of sequence found. Return true.
955  return true;
956 }
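// Worked example (hypothetical): for the indexing map
// (d0, d1, d2) -> (d2, d0, d1), the sequence [0, 1] is preserved (d0 and d1
// appear contiguously and in order as the last two results), while [1, 2] is
// not (d2 appears in the results before d1, breaking the order). A sequence
// none of whose elements appear in the results also returns true.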
957 
958 // Return the list of dimensions of the iteration domain that can be
959 // collapsed to allow for fusion with a producer that is an expand_shape
960 // operation. If all dimensions created by expansion can be collapsed in the
961 // iteration space then the reshape is defunct.
962 //
963 // Example:
964 //
965 // ```mlir
966 // #map = affine_map<(d0, d1) -> (d0, d1)>
967 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
968 // %2 = linalg.init_tensor [..] : tensor<?x4xf32>
969 // %3 = linalg.generic {
970 // indexing_maps = [#map, #map],
971 // iterator_types = ["parallel" ,"parallel"]}
972 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
973 // ```
974 //
975 // can be fused by collapsing the dimensions of the iteration space.
976 //
977 // ```mlir
978 // #map = affine_map<(d0) -> (d0)>
979 // %2 = linalg.init_tensor [..] : tensor<?xf32>
980 // %3 = linalg.generic {
981 // indexing_maps = [#map, #map],
982 // iterator_types = ["parallel"]}
983 // ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
984 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
985 // ```
986 //
987 // In the following example,
988 //
989 // ```mlir
990 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
991 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
992 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
993 // %2 = linalg.init_tensor [..] : tensor<4x?xf32>
994 // %3 = linalg.generic {
995 // indexing_maps = [#map0, #map1],
996 // iterator_types = ["parallel" ,"parallel"]}
997 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
998 // ```
999 //
1000 // the reshape cannot be fused with the generic op by collapsing the op
1001 // dimensions since the indexing maps will have to contain mods and divs
1002 // to preserve the access pattern. When no dimensions of the iteration
1003 // space are collapsible, an empty vector is returned.
1004 static SmallVector<ReassociationIndices>
1005 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1006  ArrayRef<ReassociationIndices> reassociation) {
1007  // Some basic checks for this fusion to be valid.
1008  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1009  return {};
1010 
1011  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1012  return map.isProjectedPermutation();
1013  })) {
1014  return {};
1015  }
1016 
1017  // Compute all the loops with the reduction iterator types.
1018  SmallVector<int64_t> reductionDims;
1019  for (const auto &iteratorType :
1020  llvm::enumerate(genericOp.getIteratorTypes())) {
1021  if (isReductionIterator(iteratorType.value())) {
1022  reductionDims.push_back(iteratorType.index());
1023  }
1024  }
1025 
1026  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1027  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
1028  auto iteratorTypes = genericOp.getIteratorTypes().getValue();
1029  SmallVector<ReassociationIndices> iterationSpaceReassociation;
1030  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1031  assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1032 
1033  // Ignore dims that are not folded.
1034  if (foldedRangeDims.size() == 1)
1035  continue;
1036 
1037  ReassociationIndices foldedIterationSpaceDims =
1038  getDomainReassociation(indexingMap, foldedRangeDims);
1039 
1040  // Check that the folded iteration dims do not contain already processed
1041  // dims.
1042  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1043  return processedIterationDims.count(dim);
1044  }))
1045  continue;
1046 
1047  // Check that all folded iterator types are all parallel or all reductions.
1048  Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
1049  if (!isParallelIterator(startIteratorType) &&
1050  !isReductionIterator(startIteratorType))
1051  continue;
1052  if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1053  return iteratorTypes[dim] != startIteratorType;
1054  }))
1055  continue;
1056 
1057  // If the folded dimensions correspond to a "reduction" iterator type,
1058  // the folded dimensions need to be "in-order". Strictly speaking this is
1059  // not necessary for reductions that are associative and commutative, but
1060  // we use a stricter definition of reduction for now.
1061  if (isReductionIterator(startIteratorType)) {
1062  bool isContiguous = false;
1063  for (const auto &startDim : llvm::enumerate(reductionDims)) {
1064  // Move window in `reductionDims` to start of the folded iteration dims.
1065  if (startDim.value() != foldedIterationSpaceDims[0])
1066  continue;
1067  // If the sizes don't match, it is trivially not contiguous. This condition
1068  // should not be hit.
1069  if (startDim.index() + foldedIterationSpaceDims.size() >
1070  reductionDims.size())
1071  break;
1072  // Check that the contiguity is maintained.
1073  isContiguous = true;
1074  for (const auto &foldedDim :
1075  llvm::enumerate(foldedIterationSpaceDims)) {
1076  if (reductionDims[foldedDim.index() + startDim.index()] !=
1077  foldedDim.value()) {
1078  isContiguous = false;
1079  break;
1080  }
1081  }
1082  break;
1083  }
1084  if (!isContiguous)
1085  continue;
1086  }
1087 
1088  // Check that the sequence is preserved in all indexing maps.
1089  if (llvm::any_of(genericOp.getIndexingMapsArray(),
1090  [&](AffineMap indexingMap) {
1091  return !isDimSequencePreserved(indexingMap,
1092  foldedIterationSpaceDims);
1093  }))
1094  continue;
1095 
1096  processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1097  foldedIterationSpaceDims.end());
1098  iterationSpaceReassociation.emplace_back(
1099  std::move(foldedIterationSpaceDims));
1100  }
1101 
1102  return iterationSpaceReassociation;
1103 }
1104 
1105 /// Helper class to carry state while collapsing the `linalg.generic` op.
1106 namespace {
1107 class CollapsingInfo {
1108 public:
1109  LogicalResult initialize(unsigned origNumLoops,
1110  ArrayRef<ReassociationIndices> foldedIterationDims) {
1111  llvm::SmallDenseSet<int64_t, 4> processedDims;
1112  // Find all the dims that are folded.
1113  for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1114  if (foldedIterationDim.empty())
1115  continue;
1116  // If the folded dims contain dims already folded, that's an illegal
1117  // specification. Repetition within a list is also illegal.
1118  for (auto dim : foldedIterationDim) {
1119  if (dim >= origNumLoops)
1120  return failure();
1121  if (processedDims.count(dim))
1122  return failure();
1123  processedDims.insert(dim);
1124  }
1125  collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1126  foldedIterationDim.end());
1127  }
1128  if (processedDims.size() > origNumLoops)
1129  return failure();
1130 
1131  // Add all the preserved dims of the original op as single
1132  // elements to `collapsedOpToOrigOpIterationDim`.
1133  for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1134  if (processedDims.count(dim))
1135  continue;
1136  collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1137  }
1138 
1139  llvm::sort(collapsedOpToOrigOpIterationDim,
1140  [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1141  return lhs[0] < rhs[0];
1142  });
1143  origOpToCollapsedOpIterationDim.resize(origNumLoops);
1144  for (const auto &foldedDims :
1145  llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1146  for (const auto &dim : enumerate(foldedDims.value()))
1147  origOpToCollapsedOpIterationDim[dim.value()] =
1148  std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1149  }
1150  return success();
1151  }
1152 
1153  /// Return mapping from collapsed loop domain to original loop domain.
1154  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1155  return collapsedOpToOrigOpIterationDim;
1156  }
1157 
1158  /// Return mapping from original loop domain to collapsed loop domain. The
1159  /// mapping is a pair. First value is the dimension in the collapsed loop that
1160  /// the original loop is mapped to. Second is the relative position in folded
1161  /// list of this domain. For example, if the original loop domain is 5D and
1162  /// all of it is folded into two collapsed loops, i.e.
1163  ///
1164  /// ```
1165  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1166  /// ```
1167  ///
1168  /// then
1169  ///
1170  /// ```
1171  /// origOpToCollapsedOpMapping[0] = {0, 0};
1172  /// origOpToCollapsedOpMapping[1] = {0, 1};
1173  /// origOpToCollapsedOpMapping[2] = {0, 2};
1174  /// origOpToCollapsedOpMapping[3] = {1, 0};
1175  /// origOpToCollapsedOpMapping[4] = {1, 1};
1176  /// ```
1177  ///
1178  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1179  return origOpToCollapsedOpIterationDim;
1180  }
1181 
1182  /// Return the collapsed op iteration domain rank.
1183  unsigned getCollapsedOpIterationRank() const {
1184  return collapsedOpToOrigOpIterationDim.size();
1185  }
1186 
1187 private:
1188  /// Map from the iteration domain index in collapsed op to the iteration
1189  /// domain indices in the original op.
1190  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1191 
1192  /// Map from iteration domain index in the original op to the iteration domain
1193  /// index in the collapsed op.
1194  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1195 };
1196 } // namespace
1197 
1198 /// Get the iterator types for the collapsed operation given the original
1199 /// iterator types and collapsed dimensions.
1200 static SmallVector<StringRef>
1201 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
1202  const CollapsingInfo &collapsingInfo) {
1203  SmallVector<StringRef> collapsedIteratorTypes;
1204  for (ReassociationIndicesRef foldedIterDims :
1205  collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1206  assert(!foldedIterDims.empty() &&
1207  "reassociation indices expected to have non-empty sets");
1208  // Just pick the iterator type of the first folded dim. Pre-condition checks
1209  // expected to have checked that iterator types of all folded dimensions are
1210  // the same.
1211  collapsedIteratorTypes.push_back(
1212  iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1213  }
1214  return collapsedIteratorTypes;
1215 }
1216 
1217 /// Compute the indexing map in the collapsed op that corresponds to the given
1218 /// `indexingMap` of the original operation.
1219 static AffineMap
1220 getCollapsedOpIndexingMap(AffineMap indexingMap,
1221  const CollapsingInfo &collapsingInfo) {
1222  MLIRContext *context = indexingMap.getContext();
1223  assert(indexingMap.isProjectedPermutation() &&
1224  "expected indexing map to be projected permutation");
1225  SmallVector<AffineExpr> resultExprs;
1226  auto origOpToCollapsedOpMapping =
1227  collapsingInfo.getOrigOpToCollapsedOpMapping();
1228  for (auto expr : indexingMap.getResults()) {
1229  unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1230  // If the dim is not the first of the collapsed dim, do nothing.
1231  if (origOpToCollapsedOpMapping[dim].second != 0)
1232  continue;
1233  // The next n-dims are guaranteed to be collapsed. So just use the
1234  // iteration dimension of the collapsed op.
1235  resultExprs.push_back(
1236  getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1237  }
1238  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1239  resultExprs, context);
1240 }
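// Worked example (hypothetical): if loops [0, 1] of the original op are
// collapsed into one loop and loop 2 is left untouched, the original indexing
// map (d0, d1, d2) -> (d0, d1, d2) becomes (d0, d1) -> (d0, d1) in the
// collapsed op, and (d0, d1, d2) -> (d2, d0, d1) becomes (d0, d1) -> (d1, d0).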
1241 
1242 /// Return the `reassociation` indices to use to collapse the operand when the
1243 /// iteration space of a generic op is collapsed.
1244 static SmallVector<ReassociationIndices>
1245 getOperandReassociation(AffineMap indexingMap,
1246  const CollapsingInfo &collapsingInfo) {
1247  unsigned counter = 0;
1248  SmallVector<ReassociationIndices> operandReassociation;
1249  auto origOpToCollapsedOpMapping =
1250  collapsingInfo.getOrigOpToCollapsedOpMapping();
1251  auto collapsedOpToOrigOpMapping =
1252  collapsingInfo.getCollapsedOpToOrigOpMapping();
1253  while (counter < indexingMap.getNumResults()) {
1254  unsigned dim =
1255  indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
1256  if (origOpToCollapsedOpMapping[dim].second == 0) {
1257  // This is the start of a collapsed group of iteration dimensions that
1258  // is guaranteed to be preserved in the indexing map. The number of folded
1259  // dims is obtained from the collapsed op to original op mapping.
1260  unsigned numFoldedDims =
1261  collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1262  .size();
1263  auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1264  operandReassociation.emplace_back(range.begin(), range.end());
1265  counter += numFoldedDims;
1266  }
1267  }
1268  return operandReassociation;
1269 }
1270 
1271 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1272 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1273  OpOperand *opOperand,
1274  const CollapsingInfo &collapsingInfo,
1275  OpBuilder &builder) {
1276  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1277  SmallVector<ReassociationIndices> operandReassociation =
1278  getOperandReassociation(indexingMap, collapsingInfo);
1279 
1280  // If the number of entries in the reassociation for the operand is the same
1281  // as the number of results of the indexing map, nothing to do for this operand.
1282  Value operand = opOperand->get();
1283  if (operandReassociation.size() == indexingMap.getNumResults())
1284  return operand;
1285 
1286  // Insert a reshape to collapse the dimensions.
1287  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1288  loc, operand, operandReassociation);
1289  return reshapeOp.getResult();
1290 }
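// Worked example (hypothetical): for an operand with indexing map
// (d0, d1, d2) -> (d0, d1, d2) where loops [0, 1] are collapsed, the operand
// reassociation is [[0, 1], [2]], so a tensor<4x8x?xf32> operand is collapsed
// by tensor.collapse_shape into tensor<32x?xf32> before being passed to the
// collapsed generic op.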
1291 
1292 /// Modify the `linalg.index` operations in the original generic op to their
1293 /// values in the collapsed operation.
1294 static void generateCollapsedIndexingRegion(Location loc, Block *block,
1295  const CollapsingInfo &collapsingInfo,
1296  ValueRange loopRange,
1297  PatternRewriter &rewriter) {
1298  OpBuilder::InsertionGuard g(rewriter);
1299  rewriter.setInsertionPointToStart(block);
1300 
1301  // Collect all the original index ops.
1302  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1303 
1304  // For each folded dimension list resolve the original induction variable
1305  // values in terms of the folded dimension induction variable.
1306  // i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1307  // can be inverted to
1308  // i2 = i_{folded} % d2
1309  // i1 = (i_{folded} / d2) % d1
1310  // i0 = i_{folded} / (d1 * d2)
1311  llvm::DenseMap<unsigned, Value> indexReplacementVals;
1312  for (auto &foldedDims :
1313  enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1314  ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1315  Value newIndexVal =
1316  rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1317  for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1318  indexReplacementVals[dim] =
1319  rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1320  newIndexVal =
1321  rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1322  }
1323  indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1324  }
1325 
1326  for (auto indexOp : indexOps) {
1327  auto dim = indexOp.getDim();
1328  rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1329  }
1330 }
1331 
1332 /// Implementation of fusion with reshape operation by collapsing dimensions.
1333 static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
1334  GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1335  OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
1336  // Bail on trivial no-op cases.
1337  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1338  llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1339  return foldedDims.size() <= 1;
1340  }))
1341  return failure();
1342 
1343  CollapsingInfo collapsingInfo;
1344  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1345  foldedIterationDims))) {
1346  return rewriter.notifyMatchFailure(
1347  genericOp, "illegal to collapse specified dimensions");
1348  }
1349 
1350  // Bail on non-canonical ranges.
1351  SmallVector<Range> loopRanges =
1352  cast<LinalgOp>(genericOp.getOperation())
1353  .createLoopRanges(rewriter, genericOp.getLoc());
1354  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1355  if (auto attr = ofr.dyn_cast<Attribute>())
1356  return attr.cast<IntegerAttr>().getInt() == value;
1357  llvm::APInt actual;
1358  return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
1359  actual.getSExtValue() == value;
1360  };
1361  if (!llvm::all_of(loopRanges, [&](Range range) {
1362  return opFoldIsConstantValue(range.offset, 0) &&
1363  opFoldIsConstantValue(range.stride, 1);
1364  })) {
1365  return rewriter.notifyMatchFailure(
1366  genericOp,
1367  "expected all loop ranges to have zero start and unit stride");
1368  }
1369 
1370  // Get the iterator types for the collapsed operation.
1371  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
1372  genericOp.getIteratorTypes().getValue(), collapsingInfo);
1373 
1374  // Get the indexing maps.
1375  auto indexingMaps = llvm::to_vector(
1376  llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
1377  return getCollapsedOpIndexingMap(map, collapsingInfo);
1378  }));
1379 
1380  Location loc = genericOp->getLoc();
1381 
1382  // Get the input operands.
1383  auto inputOperands = llvm::to_vector(
1384  llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
1385  return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1386  rewriter);
1387  }));
1388 
1389  // Get the output operands and result types.
1390  SmallVector<Type> resultTypes;
1391  SmallVector<Value> outputOperands;
1392  resultTypes.reserve(genericOp.getNumOutputs());
1393  outputOperands.reserve(genericOp.getNumOutputs());
1394  for (OpOperand *output : genericOp.getOutputOperands()) {
1395  Value newOutput =
1396  getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
1397  outputOperands.push_back(newOutput);
1398  resultTypes.push_back(newOutput.getType());
1399  }
1400 
1401  // Create the generic op.
1402  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
1403  loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1404  iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1405  Block *origOpBlock = &genericOp->getRegion(0).front();
1406  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
1407  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1408  collapsedOpBlock->getArguments());
1409 
1410  if (collapsedGenericOp.hasIndexSemantics()) {
1411  // Collect the loop range of the generic op.
1412  OpBuilder::InsertionGuard g(rewriter);
1413  rewriter.setInsertionPoint(collapsedGenericOp);
1414  SmallVector<Value> loopBound =
1415  llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
1416  return materializeOpFoldResult(rewriter, loc, range.size);
1417  }));
1418  generateCollapsedIndexingRegion(loc,
1419  &collapsedGenericOp->getRegion(0).front(),
1420  collapsingInfo, loopBound, rewriter);
1421  }
1422 
1423  // Insert expanding reshape for the result to get back the original result
1424  // type.
1425  SmallVector<Value> results;
1426  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
1427  Value collapsedOpResult =
1428  collapsedGenericOp->getResult(originalResult.index());
1429  auto originalResultType =
1430  originalResult.value().getType().cast<ShapedType>();
1431  auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
1432  if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1433  AffineMap indexingMap =
1434  genericOp.getTiedIndexingMapForResult(originalResult.value());
1435  SmallVector<ReassociationIndices> reassociation =
1436  getOperandReassociation(indexingMap, collapsingInfo);
1437  Value result = rewriter.create<tensor::ExpandShapeOp>(
1438  loc, originalResultType, collapsedOpResult, reassociation);
1439  results.push_back(result);
1440  } else {
1441  results.push_back(collapsedOpResult);
1442  }
1443  }
1444  return results;
1445 }
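// Rough sketch of the overall effect (illustrative IR, not from a test): given a
// producer `tensor.expand_shape` feeding an elementwise generic op, e.g.
//
//   %e = tensor.expand_shape %arg [[0, 1], [2]]
//       : tensor<?x?xf32> into tensor<?x4x?xf32>
//   %r = linalg.generic { /* 3-d parallel iteration space */ } ins(%e, ...) ...
//
// the generic op is rewritten to iterate over the collapsed 2-d space directly
// on %arg, and `tensor.expand_shape` ops are re-inserted on the results to
// restore the original result types.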
1446 
1447 namespace {
1448 
1449 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1450 /// collapsing dimensions of its iteration space.
1451 class FoldWithProducerReshapeOpByCollapsing
1452  : public OpRewritePattern<GenericOp> {
1453 public:
1454  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1455  ControlFusionFn foldReshapes,
1456  PatternBenefit benefit = 1)
1457  : OpRewritePattern<GenericOp>(context, benefit),
1458  controlFoldingReshapes(std::move(foldReshapes)) {}
1459 
1460  LogicalResult matchAndRewrite(GenericOp genericOp,
1461  PatternRewriter &rewriter) const override {
1462  for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1463  tensor::ExpandShapeOp reshapeOp =
1464  opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
1465  if (!reshapeOp)
1466  continue;
1467 
1468  SmallVector<ReassociationIndices> collapsableIterationDims =
1469  getCollapsableIterationSpaceDims(genericOp, opOperand,
1470  reshapeOp.getReassociationIndices());
1471  if (collapsableIterationDims.empty() ||
1472  !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
1473  continue;
1474  }
1475 
1476  Optional<SmallVector<Value>> replacements =
1477  collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1478  opOperand, rewriter);
1479  if (!replacements) {
1480  return rewriter.notifyMatchFailure(
1481  genericOp, "failed to do the fusion by collapsing transformation");
1482  }
1483 
1484  rewriter.replaceOp(genericOp, *replacements);
1485  return success();
1486  }
1487  return failure();
1488  }
1489 
1490 private:
1491  ControlFusionFn controlFoldingReshapes;
1492 };
1493 } // namespace
1494 
1495 //===---------------------------------------------------------------------===//
1496 // Methods and patterns that fuse constants with linalg.generic operations.
1497 //===---------------------------------------------------------------------===//
1498 
1499 namespace {
1500 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1501 /// handle cases where the constant is not single-valued.
1502 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1503 public:
1504  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1505  : OpRewritePattern<GenericOp>(context, benefit) {}
1506 
1507  LogicalResult matchAndRewrite(GenericOp genericOp,
1508  PatternRewriter &rewriter) const override {
1509  if (!genericOp.hasTensorSemantics())
1510  return failure();
1511  for (OpOperand *opOperand : genericOp.getInputOperands()) {
1512  Operation *def = opOperand->get().getDefiningOp();
1513  TypedAttr constantAttr;
1514  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1515  {
1516  DenseElementsAttr splatAttr;
1517  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1518  splatAttr.isSplat() &&
1519  splatAttr.getType().getElementType().isIntOrFloat()) {
1520  constantAttr = splatAttr.getSplatValue<TypedAttr>();
1521  return true;
1522  }
1523  }
1524  {
1525  IntegerAttr intAttr;
1526  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1527  constantAttr = intAttr;
1528  return true;
1529  }
1530  }
1531  {
1532  FloatAttr floatAttr;
1533  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1534  constantAttr = floatAttr;
1535  return true;
1536  }
1537  }
1538  return false;
1539  };
1540 
1541  auto resultValue = opOperand->get().dyn_cast<OpResult>();
1542  if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1543  continue;
1544 
1545  // The operands and the indexing_maps of the fused operation are the same as
1546  // the operands and indexing_maps of the generic operation, with the
1547  // values at the constant index dropped.
1548  SmallVector<AffineMap> fusedIndexMaps;
1549  SmallVector<Value> fusedOperands;
1550  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1551  fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1552  fusedOperands.reserve(genericOp.getNumInputs());
1553  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1554  for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1555  if (inputOperand == opOperand)
1556  continue;
1557  Value inputValue = inputOperand->get();
1558  fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1559  fusedOperands.push_back(inputValue);
1560  fusedLocs.push_back(inputValue.getLoc());
1561  }
1562  for (OpOperand *outputOperand : genericOp.getOutputOperands())
1563  fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1564 
1565  // Check if the operation shapes to loops map is computable.
1566  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1567  return rewriter.notifyMatchFailure(
1568  genericOp, "fused op loop bound computation failed");
1569  }
1570 
1571  // Create a constant scalar value from the splat constant.
1572  Value scalarConstant = rewriter.create<arith::ConstantOp>(
1573  def->getLoc(), constantAttr, constantAttr.getType());
1574 
1575  SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1576  auto fusedOp = rewriter.create<GenericOp>(
1577  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1578  /*inputs=*/fusedOperands,
1579  /*outputs=*/outputOperands,
1580  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1581  genericOp.getIteratorTypes(),
1582  /*doc=*/nullptr,
1583  /*library_call=*/nullptr);
1584 
1585  // Map the block argument corresponding to the replaced operand to the
1586  // scalar constant.
1587  Region &region = genericOp->getRegion(0);
1588  Block &entryBlock = *region.begin();
1589  BlockAndValueMapping mapping;
1590  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1591  scalarConstant);
1592  Region &fusedRegion = fusedOp->getRegion(0);
1593  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1594  mapping);
1595  rewriter.replaceOp(genericOp, fusedOp->getResults());
1596  return success();
1597  }
1598  return failure();
1599  }
1600 };
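// For example (illustrative): if one input of the generic op is produced by
// `arith.constant dense<4.0> : tensor<8xf32>`, the pattern drops that tensor
// operand, materializes `arith.constant 4.0 : f32`, and maps the corresponding
// payload block argument to the scalar, so the constant is read as a scalar
// inside the fused region instead of being carried as a full tensor operand.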
1601 
1602 } // namespace
1603 
1604 //===---------------------------------------------------------------------===//
1605 // Miscellaneous patterns that help fusion.
1606 //===---------------------------------------------------------------------===//
1607 
1608 namespace {
1609 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
1610 /// the value of the `outs` operand is not used within the op. This is only
1611 /// implemented for `linalg.generic` operations for now, but should hold for all
1612 /// linalg structured ops.
1613 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1614  using OpRewritePattern<GenericOp>::OpRewritePattern;
1615 
1616  LogicalResult matchAndRewrite(GenericOp op,
1617  PatternRewriter &rewriter) const override {
1618  rewriter.startRootUpdate(op);
1619  bool modifiedOutput = false;
1620  Location loc = op.getLoc();
1621  for (OpOperand *opOperand : op.getOutputOperands()) {
1622  if (!op.payloadUsesValueFromOperand(opOperand)) {
1623  Value operandVal = opOperand->get();
1624  auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1625  if (!operandType)
1626  continue;
1627 
1628  // If outs is sparse, leave it to the sparse compiler.
1629  if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
1630  continue;
1631 
1632  // If outs is already an `init_tensor` operation, nothing to do.
1633  auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1634  if (definingOp)
1635  continue;
1636  modifiedOutput = true;
1637  SmallVector<Value> dynamicDims;
1638  for (const auto &dim : llvm::enumerate(operandType.getShape())) {
1639  if (dim.value() != ShapedType::kDynamicSize)
1640  continue;
1641  dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1642  loc, operandVal, dim.index()));
1643  }
1644  Value initTensor = rewriter.create<InitTensorOp>(
1645  loc, dynamicDims, operandType.getShape(),
1646  operandType.getElementType());
1647  op->setOperand(opOperand->getOperandNumber(), initTensor);
1648  }
1649  }
1650  if (!modifiedOutput) {
1651  rewriter.cancelRootUpdate(op);
1652  return failure();
1653  }
1654  rewriter.finalizeRootUpdate(op);
1655  return success();
1656  }
1657 };
1658 
1659 /// Fold linalg.fill into linalg.generic
1660 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
1661  using OpRewritePattern<GenericOp>::OpRewritePattern;
1662 
1663  LogicalResult matchAndRewrite(GenericOp genericOp,
1664  PatternRewriter &rewriter) const override {
1665  if (!genericOp.hasTensorSemantics())
1666  return failure();
1667  bool fillFound = false;
1668  Block &payload = genericOp.getRegion().front();
1669  for (OpOperand *opOperand : genericOp.getInputOperands()) {
1670  if (!genericOp.payloadUsesValueFromOperand(opOperand))
1671  continue;
1672  FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1673  if (!fillOp)
1674  continue;
1675  fillFound = true;
1676  payload.getArgument(opOperand->getOperandNumber())
1677  .replaceAllUsesWith(fillOp.value());
1678  }
1679  return success(fillFound);
1680  }
1681 };
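// For example (illustrative): with
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>)
//       -> tensor<?x?xf32>
// used as an input of a linalg.generic, the payload block argument tied to
// %fill is replaced by %cst. The payload then no longer reads the filled
// tensor, which lets later canonicalizations drop the unused operand.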
1682 } // namespace
1683 
1684 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1685  RewritePatternSet &patterns,
1686  const ControlFusionFn &controlFoldingReshapes) {
1687  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1688  controlFoldingReshapes);
1689  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1690  controlFoldingReshapes);
1691 }
1692 
1693 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1694  RewritePatternSet &patterns,
1695  const ControlFusionFn &controlFoldingReshapes) {
1696  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
1697  controlFoldingReshapes);
1698 }
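// Note that the test pass below only registers the expansion-based reshape
// fusion patterns; a client that prefers collapsing could populate these
// patterns separately, e.g. (sketch, assuming the same single-use control
// function as the pass below):
//
//   RewritePatternSet patterns(context);
//   linalg::populateFoldReshapeOpsByCollapsingPatterns(
//       patterns, [](const OpResult &producer, const OpOperand &consumer) {
//         return producer.hasOneUse();
//       });
//   (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));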
1699 
1700 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1701  RewritePatternSet &patterns,
1702  const ControlFusionFn &controlElementwiseOpsFusion) {
1703  auto *context = patterns.getContext();
1704  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1705  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1706  RemoveOutsDependency>(context);
1707 }
1708 
1709 //===---------------------------------------------------------------------===//
1710 // Passes
1711 //===---------------------------------------------------------------------===//
1712 
1713 namespace {
1714 
1715 /// Pass that fuses generic ops on tensors. Used only for testing.
1716 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
1717 // patterns added here heavily depends on the cost function used. Having an
1718 // opinionated pass of this form is not recommended. Deprecate this pass in
1719 // favor of test passes that check the functionality of each of the patterns
1720 // added here individually.
1721 struct LinalgElementwiseOpFusionPass
1722  : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1723  void runOnOperation() override {
1724  Operation *op = getOperation();
1725  MLIRContext *context = op->getContext();
1726  RewritePatternSet patterns(context);
1727 
1728  // Control function: only fuse when the producer has a single use.
1729  ControlFusionFn defaultControlFn = [](const OpResult &producer,
1730  const OpOperand &consumer) {
1731  return producer.hasOneUse();
1732  };
1733 
1734  // Add elementwise op fusion patterns.
1735  populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
1736  populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
1737 
1738  // General canonicalization patterns.
1739  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1740  GenericOp::getCanonicalizationPatterns(patterns, context);
1741  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1742  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1743  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1744  patterns);
1745 
1746  // Add constant folding patterns.
1747  populateConstantFoldLinalgOperations(patterns, defaultControlFn);
1748 
1749  // Use TopDownTraversal for compile-time reasons.
1750  GreedyRewriteConfig grc;
1751  grc.useTopDownTraversal = true;
1752  (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1753  grc);
1754  }
1755 };
1756 
1757 } // namespace
1758 
1759 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1760  return std::make_unique<LinalgElementwiseOpFusionPass>();
1761 }