1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Support/LLVM.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 /// Return the indexing map to use in the fused operation for an operand of
31 /// the `producer`, given the indexing map of the result of the producer in
32 /// the consumer (`fusedConsumerArgIndexMap`).
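///
/// For example (an illustrative sketch, not taken from a test case): take a
/// producer with input map (d0, d1) -> (d1, d0) and result map
/// (d0, d1) -> (d0, d1), whose result is used by the consumer with map
/// (e0, e1) -> (e1, e0). The inverse of the result map is the identity;
/// composing the input map with it gives (d0, d1) -> (d1, d0), and composing
/// that with the consumer map yields (e0, e1) -> (e0, e1) as the indexing map
/// of the producer input in the fused op.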
33 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
34     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
35  AffineMap fusedConsumerArgIndexMap) {
36  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
37  // from consumer loop -> consumer arg tensor index/producer result tensor
38  // index. The fused loop is the same as the consumer loop. For each producer arg
39  // the indexing map to be computed is a map from consumer loop -> producer
40  // arg tensor index.
41  // producerResultIndexMap is a map from producer loop -> tensor index.
42  // Compute the inverse to get map from tensor index -> producer loop.
43  // The inverse is a map from producer result tensor index -> producer loop.
44  AffineMap invProducerResultIndexMap =
45  inversePermutation(producerResultIndexMap);
46  assert(invProducerResultIndexMap &&
47  "expected producer result indexing map to be invertible");
48 
49  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
50  // argMap is a map from producer loop -> producer arg tensor index.
51  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
52 
53  // Compose argMap with invProducerResultIndexMap to get a map from
54  // producer result tensor index -> producer arg tensor index.
55  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
56 
57  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
58  // consumer loop/ fused loop -> producer arg tensor index.
59  return t1.compose(fusedConsumerArgIndexMap);
60 }
61 
62 /// Conditions for elementwise fusion of generic operations.
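///
/// For example (an illustrative sketch): an all-parallel producer whose single
/// result uses a permutation indexing map such as
/// affine_map<(d0, d1) -> (d1, d0)> can be fused into an input operand of an
/// elementwise consumer, whereas a producer whose result map is a broadcast
/// such as affine_map<(d0, d1) -> (d0)> is rejected because that map is not
/// invertible.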
63 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
64  OpOperand *consumerOpOperand) {
65  // Producer and consumer must have tensor semantics.
66  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
67  return false;
68 
69  // Verify that
70  // - the producer has all "parallel" iterator types.
71  if (producer.getNumParallelLoops() != producer.getNumLoops())
72  return false;
73 
74  // Only allow fusing the producer of an input operand for now.
75  // TODO: allow fusing the producer of an output operand.
76  if (!consumer.isInputTensor(consumerOpOperand))
77  return false;
78 
79  // Get the consumer index map. The number of results of the consumer index
80  // map must match the number of loops of the producer.
81  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
82  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
83  return false;
84 
85  // Currently support only operations with a single result.
86  if (producer.getNumOutputs() != 1)
87  return false;
88 
89  // Finally the index_map for the result must be invertible. For now just
90  // verify it is a permutation.
91  AffineMap producerResultIndexMap =
92  producer.getTiedIndexingMap(producer.getOutputOperand(0));
93  if (!producerResultIndexMap.isPermutation())
94  return false;
95 
96  // Ensure that the fusion does not remove size information required to
97  // get the loop bounds. For non-reduction generics, this is trivially the
98  // case due to the output operand. For reductions, we need to check that after
99  // the fusion, each loop dimension has at least one input that defines it.
100  if ((consumer.getNumReductionLoops())) {
101  llvm::BitVector coveredDims(consumer.getNumLoops(), false);
102 
103  auto addToCoveredDims = [&](AffineMap map) {
104  for (auto result : map.getResults())
105  if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
106  coveredDims[dimExpr.getPosition()] = true;
107  };
108 
109  for (auto pair :
110  llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
111  Value operand = std::get<0>(pair);
112  if (operand == consumerOpOperand->get())
113  continue;
114  AffineMap operandMap = std::get<1>(pair);
115  addToCoveredDims(operandMap);
116  }
117 
118  for (OpOperand *operand : producer.getInputOperands()) {
119  AffineMap newIndexingMap =
120          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
121              operand, producerResultIndexMap, consumerIndexMap);
122  addToCoveredDims(newIndexingMap);
123  }
124  if (!coveredDims.all())
125  return false;
126  }
127 
128  return true;
129 }
130 
131 /// Generate the region of the fused tensor operation. The region of the fused
132 /// op must be empty.
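///
/// For example (an illustrative sketch): fusing a producer
/// `ins(%a, %b) outs(%p)` into the second input of a consumer
/// `ins(%c, %p_res, %d) outs(%o)` produces a fused block whose arguments are,
/// in order: %c, %a, %b, the init of %p (only if it is actually read), %d and
/// %o; the consumer argument for %p_res is replaced by the producer's yielded
/// value.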
133 static void
134 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
135                                  AffineMap consumerToProducerLoopsMap,
136  OpOperand *consumerOpOperand,
137  unsigned nloops) {
138  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
139  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
140  // Build the region of the fused op.
141  Block &producerBlock = producer->getRegion(0).front();
142  Block &consumerBlock = consumer->getRegion(0).front();
143  Block *fusedBlock = new Block();
144  fusedOp.region().push_back(fusedBlock);
145  BlockAndValueMapping mapper;
146  OpBuilder::InsertionGuard guard(rewriter);
147  rewriter.setInsertionPointToStart(fusedBlock);
148 
149  // 2. Add an index operation for every fused loop dimension and use the
150  // `consumerToProducerLoopsMap` to map the producer indices.
151  if (producer.hasIndexSemantics()) {
152  // Add an index operation for every fused loop dimension.
153  unsigned numFusedOpLoops =
154  std::max(producer.getNumLoops(), consumer.getNumLoops());
155  SmallVector<Value> fusedIndices;
156  fusedIndices.reserve(numFusedOpLoops);
157  llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
158  std::back_inserter(fusedIndices), [&](uint64_t dim) {
159  return rewriter.create<IndexOp>(producer.getLoc(), dim);
160  });
161  for (IndexOp indexOp :
162  llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
163  Value newIndex = rewriter.create<mlir::AffineApplyOp>(
164  producer.getLoc(),
165  consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
166  mapper.map(indexOp.getResult(), newIndex);
167  }
168  }
169  // TODO: allow fusing the producer of an output operand.
170  assert(consumer.isInputTensor(consumerOpOperand) &&
171  "expected producer of input operand");
172  // 3. Consumer input operands up to consumerIdx (exclusive).
173  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
174  consumerOpOperand->getOperandNumber())) // input assumption.
175  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
176 
177  // Replacing consumerIdx requires getting the cloned, yielded, value from
178  // the (cloned) producer block. This happens in step 9.
179 
180  // 4. Splice in producer's input operands.
181  for (BlockArgument bbArg :
182  producerBlock.getArguments().take_front(producer.getNumInputs()))
183  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
184 
185  // 4.b. Producer output operand/map that is fused needs to be mapped to the
186  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
187  assert(producer->getNumResults() == 1 && "expected single result producer");
188  if (producer.isInitTensor(producer.getOutputOperand(0))) {
189  BlockArgument bbArg = producerBlock.getArguments()
190  .drop_front(producer.getNumInputs())
191  // TODO: bbArg index of
192  .front();
193  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
194  }
195  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
196  for (BlockArgument bbArg :
197  consumerBlock.getArguments()
198  .take_front(consumer.getNumInputs())
199  .drop_front(consumerOpOperand->getOperandNumber() + 1))
200  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
201  // 6. All of consumer's output operands.
202  for (BlockArgument bbArg :
203  consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
204  mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
205  // 7. All of producer's output operands except the one fused.
206  // TODO: allow fusion of multi-result producers.
207  assert(producer->getNumResults() == 1 && "expected single result producer");
208 
209  // 8. Clone all producer operations except for the yield and index operations
210  // to the fused operation.
211  for (auto &op : producerBlock.without_terminator()) {
212  if (!isa<IndexOp>(op))
213  rewriter.clone(op, mapper);
214  }
215  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
216  // forward the yield operand.
217  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
218  // TODO: allow fusion of multi-result producers.
219  assert(producer->getNumResults() == 1 && "expected single result producer");
220  unsigned producerResultNumber = 0;
221  Value replacement =
222  mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
223  // Sanity checks: if replacement is not already in the mapper then it must be
224  // produced outside.
225  if (replacement == yieldOp.getOperand(producerResultNumber)) {
226  if (auto bb = replacement.dyn_cast<BlockArgument>())
227  assert(bb.getOwner() != &producerBlock &&
228  "yielded block argument must have been mapped");
229  else
230  assert(!producer->isAncestor(replacement.getDefiningOp()) &&
231  "yielded value must have been mapped");
232  }
233  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
234  replacement);
235  // 10. Clone operations from the consumer to the fused op.
236  for (auto &op : consumerBlock.getOperations())
237  rewriter.clone(op, mapper);
238 
239  // Sanity checks.
240  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
241  "Ill-formed GenericOp region");
242 }
243 
244 static Optional<SmallVector<Value>>
245 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
246  const ControlElementwiseOpsFusionFn &controlFn,
247  PatternRewriter &rewriter) {
248  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
249  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
250  !controlFn(producer->getResult(0), *consumerOpOperand))
251  return llvm::None;
252 
253  // TODO: allow fusing the producer of an output operand.
254  assert(consumer.isInputTensor(consumerOpOperand) &&
255  "expected producer of input operand");
256 
257  // Compute the fused operands list and indexing maps.
258  SmallVector<Value> fusedOperands;
259  SmallVector<AffineMap> fusedIndexMaps;
260  fusedOperands.reserve(producer->getNumOperands() +
261  consumer->getNumOperands());
262  fusedIndexMaps.reserve(producer->getNumOperands() +
263  consumer->getNumOperands());
264  // In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
265  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
266  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
267  SmallVector<OpOperand *>::iterator it =
268  llvm::find(consumerInputs, consumerOpOperand);
269  assert(it != consumerInputs.end() && "expected to find the consumer operand");
270  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
271  fusedOperands.push_back(opOperand->get());
272  fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
273  }
274  // 4. Splice in producer's input operands/maps.
275  assert(producer->getNumResults() == 1 && "expected single result producer");
276  AffineMap producerResultIndexMap =
277  producer.getTiedIndexingMap(producer.getOutputOperand(0));
278  for (OpOperand *opOperand : producer.getInputOperands()) {
279  fusedOperands.push_back(opOperand->get());
280  // Compute indexing maps for the producer args in the fused operation.
281    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
282        opOperand, producerResultIndexMap,
283  consumer.getTiedIndexingMap(consumerOpOperand));
284  fusedIndexMaps.push_back(map);
285  }
286  // 4.b. Producer output operand/map that is fused needs to be passed if it is
287  // an "initTensor" (i.e. its value is actually read).
288  assert(producer->getNumResults() == 1 && "expected single result producer");
289  if (producer.isInitTensor(producer.getOutputOperand(0))) {
290  fusedOperands.push_back(producer.getOutputOperand(0)->get());
291  // Compute indexing maps for the producer args in the fused operation.
292    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
293        producer.getOutputOperand(0), producerResultIndexMap,
294  consumer.getTiedIndexingMap(consumerOpOperand));
295  fusedIndexMaps.push_back(map);
296  }
297  // 5. Remaining consumer's input operands/maps (drop past index
298  // `consumerIdx`).
299  for (OpOperand *opOperand :
300  llvm::make_range(std::next(it), consumerInputs.end())) {
301  fusedOperands.push_back(opOperand->get());
302  fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
303  }
304  // 6. All of consumer's output operands (skip operands: added by the builder).
305  for (OpOperand *opOperand : consumer.getOutputOperands())
306  fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
307  // 7. All of producer's output operands/maps except the one fused.
308  // TODO: allow fusion of multi-result producers.
309  assert(producer->getNumResults() == 1 && "expected single result producer");
310 
311  // Generate the fused op.
312  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
313  auto fusedOp = rewriter.create<GenericOp>(
314  consumer.getLoc(), consumer->getResultTypes(),
315  /*inputs=*/fusedOperands,
316  // TODO: handle outputs.
317  consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
318  consumer.iterator_types(),
319  /*doc=*/nullptr,
320  /*library_call=*/nullptr);
321  if (!fusedOp.getShapesToLoopsMap()) {
322  // Fused op has invalid indexing maps. Typically this means something is off
323  // in the input, but going ahead here would result in verification errors.
324  // So cleanup and abort.
325  rewriter.eraseOp(fusedOp);
326  return llvm::None;
327  }
328 
329  // Construct an AffineMap from consumer loops to producer loops.
330  // consumer loop -> tensor index
331  AffineMap consumerResultIndexMap =
332  consumer.getTiedIndexingMap(consumerOpOperand);
333  // tensor index -> producer loop
334  AffineMap invProducerResultIndexMap =
335  inversePermutation(producerResultIndexMap);
336  assert(invProducerResultIndexMap &&
337  "expected producer result indexing map to be invertible");
338  // consumer loop -> producer loop
339  AffineMap consumerToProducerLoopsMap =
340  invProducerResultIndexMap.compose(consumerResultIndexMap);
341 
342  generateFusedElementwiseOpRegion(rewriter, fusedOp,
343  consumerToProducerLoopsMap,
344  consumerOpOperand, consumer.getNumLoops());
345  return SmallVector<Value>(fusedOp->getResults());
346 }
347 
348 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
349 /// provided, given the shape of the source tensor that corresponds to the
350 /// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
351 /// are "row-major" ordered logically.
352 ///
353 /// For example:
354 ///
355 /// %0 = op ... : tensor<?x?x4x5xf32>
356 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
357 ///
358 /// and reshape:
359 /// %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
360 /// tensor<?x?x4x5xf32> into tensor<?x?xf32>
361 ///
362 /// would be rewritten into:
363 /// %0 = op ... : tensor<?x?x4x5xf32>
364 /// with output index_map
365 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
366 template <typename TensorReshapeOp>
367 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
368                                         TensorReshapeOp reshapeOp) {
369  constexpr bool isExpanding =
370      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
371  ArrayRef<int64_t> sourceShape =
372  (isExpanding ? reshapeOp.getResultType().getShape()
373  : reshapeOp.getSrcType().getShape());
374  SmallVector<AffineExpr> resultExprs;
375  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
376  MLIRContext *context = sourceMap.getContext();
377 
378  // Compute the result exprs based on the reassociation maps.
379  for (auto &indices : reshapeOp.getReassociationIndices()) {
380  // Assume that they are in-order and contiguous (already checked in
381  // verifier).
382  assert(!indices.empty());
383  SmallVector<int64_t> sizes;
384  SmallVector<AffineExpr> dimExprs;
385  for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
386  sourceExprs.slice(indices[0], indices.size()))) {
387  if (std::get<0>(en) == 1)
388  continue;
389  sizes.push_back(std::get<0>(en));
390  dimExprs.push_back(std::get<1>(en));
391  }
392  AffineExpr linearizedExpr =
393  makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
394  resultExprs.push_back(linearizedExpr);
395  }
396  // The new affine map cannot drop unused dimensions but some new symbols may
397  // have been added. Create a map with at least as many dimensions/symbols as
398  // the original affine map.
399  int64_t maxDim = -1;
400  int64_t maxSym = -1;
401  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
402  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
403  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
404  return AffineMap::get(numDims, numSyms, resultExprs, context);
405 }
406 
407 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
408 // producer). Fusing when the operand has higher rank will require use of mods and
409 // divs in the indexing maps of the fused op which would make it non-invertible.
410 static bool isTensorReshapeOpFoldableByLinearization(
411     tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
412  if (!asProducer)
413  return false;
414  return useIndexMap.isPermutation();
415 }
416 
417 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
418 // consumer).
419 static bool
420 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
421  AffineMap useIndexMap,
422  bool asProducer) {
423  if (asProducer)
424  return false;
425  return useIndexMap.isPermutation();
426 }
427 
428 /// Check if the reshape operation is only expansion into/collapsing of
429 /// unit-dimension.
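///
/// For example (an illustrative sketch with static shapes):
///   tensor.expand_shape %0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32>
/// only expands a unit dimension, whereas
///   tensor.expand_shape %0 [[0, 1]] : tensor<20xf32> into tensor<4x5xf32>
/// does not.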
430 template <typename TensorReshapeOp>
431 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
432  constexpr bool isExpanding =
433      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
434  ArrayRef<int64_t> expandedShape =
435  (isExpanding ? reshapeOp.getResultType().getShape()
436  : reshapeOp.getSrcType().getShape());
437  for (auto &indices : reshapeOp.getReassociationIndices()) {
438  unsigned numUnitDims = 0;
439  for (int64_t position : indices)
440  if (expandedShape[position] == 1)
441  numUnitDims++;
442  if (numUnitDims != indices.size() - 1)
443  return false;
444  }
445  return true;
446 }
447 
448 /// Conditions for folding a generic operation with a reshape op by expanding
449 /// the iteration space dimensionality for tensor operations. These are
450 /// preconditions assumed by `fuseWithReshapeByExpansion` which implements the
451 /// following fusion pattern.
452 ///
453 /// Consider
454 ///
455 /// %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
456 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
457 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
458 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
459 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
460 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
461 ///
462 /// The reshape can be folded into the `genericOp` if its loop dimensionality
463 /// is increased to match the result (operand) of the tensor_expand_shape.
464 /// The indexing_map of the fused tensor in the `genericOp` and the
465 /// reassociation map helps compute the indexing maps of the modified op.
466 /// For the above example, based on the reassociation map it
467 /// can be concluded that
468 ///
469 /// - The loop used to access the first dimension of the fused tensor is split
470 /// into two.
471 /// - The loop used to access the second dimension of the fused tensor is kept
472 /// as is.
473 /// - The loop used to access the third dimension of the fused tensor is split
474 /// into three.
475 ///
476 /// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
477 /// modified op, then
478 ///
479 /// d0 -> e0, e1
480 /// d1 -> e2, e3, e4
481 /// d2 -> e5
482 ///
483 /// substituting this, the generic op can be rewritten as
484 ///
485 /// %d = linalg.generic ins(%0, %1 : )
486 /// indexing_maps =
487 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
488 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
489 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
490 ///
491 /// Since operands to the linalg generic are now 5D, reshapes can be introduced
492 /// to make it consistent
493 ///
494 /// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
495 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
496 /// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
497 /// : tensor<?x?xf32> into tensor<?x?x?x?xf32>
498 ///
499 /// The added reshapes are again expanding patterns, so they will get fused
500 /// with their producers if possible.
501 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
502  OpOperand *fusableOpOperand) {
503  // Is fusable only if:
504  // - All the indexing maps for operands and results are projected
505  // permutations.
506  // - The fused tensor is not a scalar.
507  // - All the loops are parallel loops.
508  return genericOp.hasTensorSemantics() &&
509  llvm::all_of(genericOp.indexing_maps().getValue(),
510  [](Attribute attr) {
511  return attr.cast<AffineMapAttr>()
512  .getValue()
513  .isProjectedPermutation();
514  }) &&
515  genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
516  llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
517  return attr.cast<StringAttr>().getValue() ==
518                     getParallelIteratorTypeName();
519  });
520 }
521 
522 namespace {
523 /// Information needed to expand a generic operation to fold the reshape with
524 /// it.
525 class ExpansionInfo {
526 public:
527  // Computes the mapping from original dimensions of the op to the dimensions
528  // of the expanded op given the `indexingMap` of the fused operand/result of
529  // the generic op, the `reassociationMaps` of the reshape op and the shape of
530  // the expanded op.
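  //
  // For example (an illustrative sketch): for a fused operand with indexing
  // map (d0, d1) -> (d0, d1), reshape reassociation [[0, 1], [2]] and
  // expanded shape [4, 5, 6], this computes expandedShapeMap = [[4, 5], [6]],
  // reassociation = [[0, 1], [2]] and an expanded op with 3 dimensions.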
531  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
532  ArrayRef<AffineMap> reassociationMaps,
533  ArrayRef<int64_t> expandedShape,
534  ArrayRef<int64_t> collapsedShape,
535  PatternRewriter &rewriter);
536  unsigned getOrigOpNumDims() const { return reassociation.size(); }
537  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
538  ReassociationIndicesRef getExpandedDims(unsigned i) const {
539  return reassociation[i];
540  }
541  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
542  return expandedShapeMap[i];
543  }
544  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
545 
546 private:
547  /// Reassociation from the dimensions in the original operation to the
548  /// dimension of the expanded operation.
549  SmallVector<ReassociationIndices> reassociation;
550  /// Mapping from extent of loops in the original operation, to the extent of
551  /// loops in the expanded operation.
552  SmallVector<SmallVector<int64_t>> expandedShapeMap;
553  /// Extent of the loop in the original operation.
554  SmallVector<int64_t> originalLoopExtent;
555  unsigned expandedOpNumDims;
556 };
557 } // namespace
558 
559 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
560  OpOperand *fusableOpOperand,
561  ArrayRef<AffineMap> reassociationMaps,
562  ArrayRef<int64_t> expandedShape,
563  ArrayRef<int64_t> collapsedShape,
564  PatternRewriter &rewriter) {
565  if (reassociationMaps.empty())
566  return failure();
567  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
568 
569  Optional<SmallVector<int64_t, 4>> originalLoopRange =
570  linalgOp.getStaticLoopRanges();
571  if (!originalLoopRange)
572  return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
573  originalLoopExtent.assign(originalLoopRange->begin(),
574  originalLoopRange->end());
575 
576  reassociation.clear();
577  expandedShapeMap.clear();
578  // Compute the number of dimensions in the expanded op that correspond to each
579  // dimension of the original op.
580  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
581  expandedShapeMap.resize(fusedIndexMap.getNumDims());
582  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
583  unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
584  AffineMap foldedDims = reassociationMaps[resultExpr.index()];
585  numExpandedDims[pos] = foldedDims.getNumResults();
586  ArrayRef<int64_t> shape =
587  expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
588  expandedShapeMap[pos].assign(shape.begin(), shape.end());
589  }
590  // The remaining dimensions remain the same.
591  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
592  if (expandedShapeMap[i].empty())
593  expandedShapeMap[i] = {originalLoopExtent[i]};
594 
595  // Compute reassociation map from the original op to the expanded op.
596  unsigned sum = 0;
597  reassociation.reserve(fusedIndexMap.getNumDims());
598  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
599  auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
600  reassociation.emplace_back(seq.begin(), seq.end());
601  sum += numFoldedDim.value();
602  }
603  expandedOpNumDims = sum;
604  return success();
605 }
606 
607 /// Expanding the body of a linalg operation requires adaptations of the accessed
608 /// loop indices. Specifically, accesses of indices in the original operation need
609 /// to be replaced with linearizations of indices in the expanded op. That
610 /// requires the shape of the expanded dimensions to be static (at least all but
611 /// the most significant). For now check that these are all statically sized.
612 /// Note that this could be extended to handle dynamic case, but the
613 /// implementation below uses `affine.apply` which seems to have issues when the
614 /// shapes are not static.
615 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
616                                            const ExpansionInfo &expansionInfo,
617  PatternRewriter &rewriter) {
618  if (!genericOp.hasIndexSemantics())
619  return success();
620  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
621  ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
622  if (expandedShape.size() == 1)
623  continue;
624  for (int64_t shape : expandedShape.drop_front()) {
625  if (ShapedType::isDynamic(shape)) {
626  return rewriter.notifyMatchFailure(
627  genericOp, "cannot expand due to index semantics and dynamic dims");
628  }
629  }
630  }
631  return success();
632 }
633 
634 /// Return the indexing map to use in the expanded op given the
635 /// `indexingMap` of the original operation.
636 static AffineMap
637 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
638                            const ExpansionInfo &expansionInfo) {
639  SmallVector<AffineExpr> newExprs;
640  for (AffineExpr expr : indexingMap.getResults()) {
641  unsigned pos = expr.cast<AffineDimExpr>().getPosition();
642  SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
643  llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
644  return builder.getAffineDimExpr(static_cast<unsigned>(v));
645  }));
646  newExprs.append(expandedExprs.begin(), expandedExprs.end());
647  }
648  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
649  indexingMap.getNumSymbols(), newExprs,
650  builder.getContext());
651 }
652 
653 /// Return the type of the operand/result to use in the expanded op given the
654 /// type in the original op.
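///
/// For example (an illustrative sketch): a tensor<4x5xf32> operand accessed
/// with map (d0, d1) -> (d0, d1), where d0 expands to shape [2, 2] and d1
/// stays [5], becomes tensor<2x2x5xf32> in the expanded op.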
655 static RankedTensorType getExpandedType(RankedTensorType originalType,
656  AffineMap indexingMap,
657  const ExpansionInfo &expansionInfo) {
658  SmallVector<int64_t> expandedShape;
659  for (AffineExpr expr : indexingMap.getResults()) {
660  unsigned dim = expr.cast<AffineDimExpr>().getPosition();
661  auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
662  expandedShape.append(dimExpansion.begin(), dimExpansion.end());
663  }
664  return RankedTensorType::get(expandedShape, originalType.getElementType());
665 }
666 
667 /// Returns the reassociation maps to use in the `tensor.expand_shape`
668 /// operation to convert the operands of the original operation to operands of
669 /// the expanded operation. The same method is used to compute the
670 /// `tensor.collapse_shape` used to collapse the result of the expanded
671 /// op to get the value that can replace all uses of the results of the original
672 /// op.
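///
/// For example (an illustrative sketch): for an operand with indexing map
/// (d0, d1, d2) -> (d0, d2), where d0 expands into two dimensions and d2 into
/// one, the reassociation is [[0, 1], [2]].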
673 static SmallVector<ReassociationIndices>
674 getReassociationForExpansion(AffineMap indexingMap,
675                              const ExpansionInfo &expansionInfo) {
676  SmallVector<ReassociationIndices> reassociation;
677  unsigned numReshapeDims = 0;
678  for (AffineExpr expr : indexingMap.getResults()) {
679  unsigned dim = expr.cast<AffineDimExpr>().getPosition();
680  auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
681  SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
682  llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
683  reassociation.emplace_back(std::move(indices));
684  numReshapeDims += numExpandedDims;
685  }
686  return reassociation;
687 }
688 
689 /// Update the body of an expanded linalg operation having index semantics. The
690 /// indices of the original operation need to be recovered by linearizing the
691 /// indices of the corresponding dimensions of the expanded operation. For now it
692 /// is assumed that the shapes of the expanded operation needed for
693 /// linearization are static.
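///
/// For example (an illustrative sketch): if loop d1 of the original op is
/// expanded into (e2, e3, e4) with static sizes s3 and s4 for the two
/// innermost expanded dimensions, the original index is recovered as
/// d1 = (e2 * s3 + e3) * s4 + e4.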
694 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
695                                           Location loc, Region &fusedRegion,
696  const ExpansionInfo &expansionInfo) {
697  // Replace the original indices by the linearization of the expanded indices.
698  for (IndexOp indexOp :
699  llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
700  ArrayRef<int64_t> expandedDims =
701  expansionInfo.getExpandedDims(indexOp.dim());
702  assert(!expandedDims.empty() && "expected valid expansion info");
703 
704  // Skip index operations that are not affected by the expansion.
705  if (expandedDims.size() == 1 &&
706  expandedDims.front() == (int64_t)indexOp.dim())
707  continue;
708 
709  // Linearize the expanded indices of the original index dimension.
710  OpBuilder::InsertionGuard guard(rewriter);
711  rewriter.setInsertionPointAfter(indexOp);
712  ArrayRef<int64_t> expandedDimsShape =
713  expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
714  SmallVector<Value> expandedIndices;
715  expandedIndices.reserve(expandedDims.size() - 1);
716  llvm::transform(
717  expandedDims.drop_front(), std::back_inserter(expandedIndices),
718  [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
719  Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
720  for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
721  assert(!ShapedType::isDynamic(std::get<0>(it)));
722  AffineExpr idx, acc;
723  bindDims(rewriter.getContext(), idx, acc);
724  newIndex = rewriter.create<AffineApplyOp>(
725  indexOp.getLoc(), idx + acc * std::get<0>(it),
726  ValueRange{std::get<1>(it), newIndex});
727  }
728  rewriter.replaceOp(indexOp, newIndex);
729  }
730 }
731 
732 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
733 /// and a generic op as explained in `isFusableWithReshapeByDimExpansion`. Assumes
734 /// that those conditions have been satisfied.
735 static Optional<SmallVector<Value>>
736 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
737  OpOperand *fusableOpOperand,
738  PatternRewriter &rewriter) {
739  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
740  "preconditions for fuse operation failed");
741  // Check if reshape is expanding or collapsing.
742  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
743  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
744  bool isExpanding = (expandingReshapeOp != nullptr);
745  RankedTensorType expandedType = isExpanding
746  ? expandingReshapeOp.getResultType()
747  : collapsingReshapeOp.getSrcType();
748  RankedTensorType collapsedType = isExpanding
749  ? expandingReshapeOp.getSrcType()
750  : collapsingReshapeOp.getResultType();
751 
752  ExpansionInfo expansionInfo;
753  if (failed(expansionInfo.compute(
754  genericOp, fusableOpOperand,
755  isExpanding ? expandingReshapeOp.getReassociationMaps()
756  : collapsingReshapeOp.getReassociationMaps(),
757  expandedType.getShape(), collapsedType.getShape(), rewriter)))
758  return llvm::None;
759 
760  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
761  return llvm::None;
762 
763  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
764  llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
765  return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
766  }));
767 
768  SmallVector<Value> expandedOpOperands;
769  expandedOpOperands.reserve(genericOp.getNumInputs());
770  for (OpOperand *opOperand : genericOp.getInputOperands()) {
771  if (opOperand == fusableOpOperand) {
772  expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
773  : collapsingReshapeOp.src());
774  continue;
775  }
776  if (genericOp.isInputTensor(opOperand)) {
777  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
778  auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
779  RankedTensorType expandedOperandType =
780  getExpandedType(opOperandType, indexingMap, expansionInfo);
781  if (expandedOperandType != opOperand->get().getType()) {
782  // Reshape the operand to get the right type.
783  SmallVector<ReassociationIndices> reassociation =
784  getReassociationForExpansion(indexingMap, expansionInfo);
785      if (failed(reshapeLikeShapesAreCompatible(
786              [&](const Twine &msg) {
787  return rewriter.notifyMatchFailure(genericOp, msg);
788  },
789  opOperandType.getShape(), expandedOperandType.getShape(),
790  reassociation,
791  /*isExpandingReshape=*/true)))
792  return llvm::None;
793  expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
794  genericOp.getLoc(), expandedOperandType, opOperand->get(),
795  reassociation));
796  continue;
797  }
798  }
799  expandedOpOperands.push_back(opOperand->get());
800  }
801 
802  Location loc = genericOp.getLoc();
803  SmallVector<Value> outputs;
804  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
805  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
806  auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
807  RankedTensorType expandedOutputType =
808  getExpandedType(opOperandType, indexingMap, expansionInfo);
809  if (expandedOutputType != opOperand->get().getType()) {
810  SmallVector<ReassociationIndices> reassociation =
811  getReassociationForExpansion(indexingMap, expansionInfo);
812      if (failed(reshapeLikeShapesAreCompatible(
813              [&](const Twine &msg) {
814  return rewriter.notifyMatchFailure(genericOp, msg);
815  },
816  opOperandType.getShape(), expandedOutputType.getShape(),
817  reassociation,
818  /*isExpandingReshape=*/true)))
819  return llvm::None;
820  outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
821  genericOp.getLoc(), expandedOutputType, opOperand->get(),
822  reassociation));
823  }
824  }
825 
826  // The iterator types of the expanded op are all parallel.
827  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
828                                        getParallelIteratorTypeName());
829 
830  TypeRange resultTypes = ValueRange(outputs).getTypes();
831  auto fusedOp =
832  rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
833  /*inputs=*/expandedOpOperands, outputs,
834  expandedOpIndexingMaps, iteratorTypes);
835  Region &fusedRegion = fusedOp->getRegion(0);
836  Region &originalRegion = genericOp->getRegion(0);
837  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
838 
839  // Update the index accesses after the expansion.
840  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
841 
842  // Reshape the result values to their original shape if this is a collapsing
843  // reshape folded into its consumer.
844  SmallVector<Value> resultVals;
845  for (OpResult opResult : genericOp->getOpResults()) {
846  int64_t resultNumber = opResult.getResultNumber();
847  if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
848  SmallVector<ReassociationIndices> reassociation =
849          getReassociationForExpansion(
850              genericOp.getTiedIndexingMap(
851  genericOp.getOutputOperand(resultNumber)),
852  expansionInfo);
853  resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
854  genericOp.getLoc(), opResult.getType(),
855  fusedOp->getResult(resultNumber), reassociation));
856  } else {
857  resultVals.push_back(fusedOp->getResult(resultNumber));
858  }
859  }
860  // Assuming a single result.
861  return resultVals;
862 }
863 
864 namespace {
865 
866 /// Pattern to fold tensor_expand_shape op with its consumer by using the source
867 /// of the reshape op as the operand in the consumer (instead of the result of
868 /// the tensor_expand_shape). The corresponding index map in the consumer
869 /// needs to be modified to linearize the folded dimension.
870 ///
871 /// For example,
872 ///
873 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
874 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
875 /// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
876 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
877 /// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
878 /// -> tensor<?x?x4x?xf32>
879 ///
880 /// can be folded into
881 ///
882 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
883 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
884 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
885 /// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
886 /// -> tensor<?x?x4x?xf32>
887 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
888 struct FoldProducerReshapeOpByLinearization
889  : public OpRewritePattern<GenericOp> {
890  using OpRewritePattern<GenericOp>::OpRewritePattern;
891 
892  LogicalResult matchAndRewrite(GenericOp genericOp,
893  PatternRewriter &rewriter) const override {
894  if (!genericOp.hasTensorSemantics())
895  return failure();
896  SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
897  for (const auto &en : llvm::enumerate(inputOperands)) {
898  auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
899  if (!reshapeOp)
900  continue;
901 
902      if (!isTensorReshapeOpFoldableByLinearization(
903              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
904  /*asProducer =*/true) ||
905  (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
906  continue;
907 
908  // Compute the fused operands list,
909  SmallVector<Value> fusedOperands = genericOp.getInputOperands();
910  fusedOperands[en.index()] = reshapeOp.src();
911  SmallVector<Value> outputOperands = genericOp.getOutputOperands();
912  llvm::append_range(fusedOperands, outputOperands);
913 
914  // Compute indexing_maps for the fused operation. The indexing_maps for
915  // the operands of the consumers that aren't fused are the same.
916  SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
917 
918  // Compute the indexing map to use for the result of the producer.
919  AffineMap modifiedMap =
920  linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
921  // The modified map cannot have symbols.
922  if (modifiedMap.getNumSymbols())
923  return failure();
924  for (AffineExpr expr : modifiedMap.getResults()) {
925  if (!expr.isPureAffine())
926  return failure();
927  }
928  fusedIndexMaps[en.index()] = modifiedMap;
929 
930  // Further check that the resulting index maps can be fused and
931  // inverted. Without this the resultant op is not legal.
932  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
933  return rewriter.notifyMatchFailure(
934  genericOp, "fused op loop bound computation failed");
935  }
936 
937  rewriter.startRootUpdate(genericOp);
938  genericOp->setOperands(fusedOperands);
939  genericOp.indexing_mapsAttr(
940  rewriter.getAffineMapArrayAttr(fusedIndexMaps));
941  rewriter.finalizeRootUpdate(genericOp);
942  return success();
943  }
944  return failure();
945  }
946 };
947 
948 static SmallVector<ReassociationIndices>
949 getReassociationIndices(ArrayRef<AffineMap> maps) {
950  SmallVector<ReassociationIndices> reassociation;
951  for (AffineMap map : maps) {
952  ReassociationIndices indices;
953  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
954  unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
955  indices.push_back(pos);
956  }
957  reassociation.push_back(indices);
958  }
959  return reassociation;
960 }
961 
962 /// Pattern to move rank reducing reshape after an elementwise linalg generic
963 /// op. This is useful to expose more fusion opportunities between named ops and
964 /// generic ops. This can only be done if there is no broadcast or permutation
965 /// within the dimensions we need to merge.
966 ///
967 /// For example,
968 ///
969 /// %0 = tensor.expand_shape %A [[0, 1], [2]]
970 /// : tensor<12544x16xf32> into tensor<112x112x16xf32>
971 /// %2 = linalg.generic {indexing_maps = [
972 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
973 /// affine_map<(d0, d1, d2) -> (d2)>,
974 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
975 /// ["parallel", "parallel", "parallel"]} {
976 /// } -> tensor<112x112x16xf32>
977 ///
978 /// into
979 ///
980 /// %2 = linalg.generic {indexing_maps = [
981 /// affine_map<(d0, d1) -> (d0, d1)>,
982 /// affine_map<(d0, d1) -> (d1)>,
983 /// affine_map<(d0, d1) -> (d0, d1)>],
984 /// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
985 /// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
986 /// } -> tensor<12544x16xf32>
987 /// %3 = tensor.expand_shape %2 [[0, 1], [2]]
988 /// : tensor<12544x16xf32> into tensor<112x112x16xf32>
989 struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
990  using OpRewritePattern<GenericOp>::OpRewritePattern;
991 
992  LogicalResult matchAndRewrite(GenericOp genericOp,
993  PatternRewriter &rewriter) const override {
994  // Only apply to elementwise linalg on tensor.
995  if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
996  genericOp.getNumParallelLoops() != genericOp.getNumLoops())
997  return failure();
998  // Only support identity output maps. It could be extended to permutations if
999  // needed.
1000  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
1001  return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
1002  }))
1003  return failure();
1004  int64_t destRank = genericOp.getNumParallelLoops();
1005  SmallVector<Value> newOperands = genericOp.getInputOperands();
1006  tensor::ExpandShapeOp reshapeFound;
1007  // 1. Look for tensor_expand_shape operands and record the dimensions that
1008  // are merged.
1009  SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
1010  for (const auto &en : llvm::enumerate(inputOperands)) {
1011  auto reshapeOp =
1012  en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
1013  if (!reshapeOp)
1014  continue;
1015  // TODO: We could support non-identity map as long as the merged
1016  // dimensions are still contiguous.
1017  if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
1018  continue;
1019  if (reshapeFound) {
1020  // Only support a second reshape op if it has the same reassociation maps.
1021  if (reshapeFound.getReassociationMaps() ==
1022  reshapeOp.getReassociationMaps())
1023  newOperands[en.index()] = reshapeOp.src();
1024  continue;
1025  }
1026  reshapeFound = reshapeOp;
1027  newOperands[en.index()] = reshapeOp.src();
1028  }
1029  if (!reshapeFound)
1030  return failure();
1031 
1032  // Calculate the reassociation indices and the associated reverse map.
1033  SmallVector<ReassociationIndices> reassociation =
1034  getReassociationIndices(reshapeFound.getReassociationMaps());
1035  SmallVector<unsigned> remap(destRank);
1036  for (auto &indices : llvm::enumerate(reassociation)) {
1037  for (int64_t index : indices.value()) {
1038  remap[index] = indices.index();
1039  }
1040  }
1041  // 2. Verify that we can merge the dimensions in the linalg and that we
1042  // don't need to create new reshape operands. Inserting new reshape
1043  // operands would defeat the purpose of the transformation.
1044  for (const auto &en : llvm::enumerate(inputOperands)) {
1045  if (en.value()->get() == newOperands[en.index()]) {
1046  AffineMap map = genericOp.getTiedIndexingMap(en.value());
1047  for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
1048  if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
1049  return failure();
1050  }
1051  }
1052  }
1053 
1054  // 3. Calculate the affine map remapping and the reassociation to apply to
1055  // output tensors.
1056  SmallVector<AffineMap> newMaps;
1057  unsigned newRank = reassociation.size();
1058  for (auto map : genericOp.getIndexingMaps()) {
1059  SmallVector<AffineExpr> newExprs;
1060  for (auto expr : map.getResults()) {
1061  unsigned position = expr.template cast<AffineDimExpr>().getPosition();
1062  // Skip merged dimensions except for the last one of each group.
1063  if (reassociation[remap[position]].back() == position) {
1064  newExprs.push_back(
1065  getAffineDimExpr(remap[position], genericOp.getContext()));
1066  }
1067  }
1068  newMaps.push_back(
1069  AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
1070  }
1071 
1072  // 4. Reshape the output tensors.
1073  SmallVector<Value> newOutputs;
1074  SmallVector<Type> newOutputTypes;
1075  for (auto output : genericOp.outputs()) {
1076  auto newOutputType = RankedTensorType::get(
1077  reshapeFound.getSrcType().getShape(),
1078  output.getType().template cast<RankedTensorType>().getElementType());
1079  Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
1080  genericOp->getLoc(), newOutputType, output, reassociation);
1081  newOutputTypes.push_back(newOutputType);
1082  newOutputs.push_back(newOutput);
1083  }
1084  // 5. Create a new generic op with lower rank.
1085  SmallVector<StringRef> iteratorTypes(newRank,
1086                                       getParallelIteratorTypeName());
1087  auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
1088  newOperands, newOutputs, newMaps,
1089  iteratorTypes);
1090  rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
1091  newOp.region().begin());
1092  // 6. Reshape the results so that the types match the uses.
1093  SmallVector<Value> newResults;
1094  for (const auto &result : llvm::enumerate(newOp->getResults())) {
1095  newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
1096  genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
1097  result.value(), reassociation));
1098  }
1099  rewriter.replaceOp(genericOp, newResults);
1100  return success();
1101  }
1102 };
1103 
1104 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
1105 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
1106 /// in the consumer is expanded.
1107 class FoldWithProducerReshapeOpByExpansion
1108  : public OpRewritePattern<GenericOp> {
1109 public:
1110  FoldWithProducerReshapeOpByExpansion(
1111  MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1112  PatternBenefit benefit = 1)
1113  : OpRewritePattern<GenericOp>(context, benefit),
1114  controlFoldingReshapes(std::move(foldReshapes)) {}
1115 
1116  LogicalResult matchAndRewrite(GenericOp genericOp,
1117  PatternRewriter &rewriter) const override {
1118  for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1119  tensor::CollapseShapeOp reshapeOp =
1120  opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1121  if (!reshapeOp)
1122  continue;
1123  // Fold only if
1124  // - The tensor reshape op is folding.
1125  // - All constraints of fusing with reshape by expansion are met.
1126  if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1127  (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1128  continue;
1129 
1130  Optional<SmallVector<Value>> replacementValues =
1131  fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1132  if (!replacementValues)
1133  return failure();
1134  rewriter.replaceOp(genericOp, replacementValues.getValue());
1135  return success();
1136  }
1137  return failure();
1138  }
1139 
1140 private:
1141  ControlElementwiseOpsFusionFn controlFoldingReshapes;
1142 };
1143 
1144 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
1145 /// producer. The corresponding index map in the consumer needs to be modified
1146 /// to linearize the folded dimension.
1147 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
1148 struct FoldConsumerReshapeOpByLinearization
1149  : public OpRewritePattern<TensorReshapeOp> {
1150  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1151 
1152  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1153  PatternRewriter &rewriter) const override {
1154  GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
1155  if (!producer || !producer.hasTensorSemantics() ||
1156  producer.getNumOutputs() != 1 ||
1157        !isTensorReshapeOpFoldableByLinearization(
1158            reshapeOp,
1159  producer.getTiedIndexingMap(producer.getOutputOperand(0)),
1160  /*asProducer =*/false) ||
1161  (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
1162  return failure();
1163  // The indexing_maps for the operands of the fused operation are the same as
1164  // those for the operands of the producer.
1165  SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
1166 
1167  // Compute the indexing map to use for the operand of the producer.
1168  AffineMap modifiedMap = linearizeCollapsedDims(
1169  producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
1170  for (AffineExpr expr : modifiedMap.getResults()) {
1171  if (!expr.isPureAffine()) {
1172  return rewriter.notifyMatchFailure(
1173  producer, "fused op indexing map is not affine");
1174  }
1175  }
1176  fusedIndexMaps.back() = modifiedMap;
1177 
1178  // Further check that the resulting index maps can be fused and
1179  // inverted. Without this the resultant op is not legal.
1180  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1181  return rewriter.notifyMatchFailure(
1182  producer, "fused op loop bound computation failed");
1183  }
1184 
1185  Location loc = producer.getLoc();
1186  SmallVector<Value> inputOperands = producer.getInputOperands();
1187  Value output = rewriter.create<TensorReshapeOp>(
1188  loc, producer.getOutputOperand(0)->get(),
1189  reshapeOp.getReassociationExprs());
1190  auto fusedOp = rewriter.create<GenericOp>(
1191  loc, reshapeOp.getResultType(),
1192  /*inputs=*/inputOperands,
1193  // TODO: handle outputs.
1194  /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1195  producer.iterator_types(),
1196  /*doc=*/nullptr,
1197  /*library_call=*/nullptr);
1198  auto &fusedRegion = fusedOp->getRegion(0);
1199  rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
1200  fusedRegion.begin());
1201  rewriter.replaceOp(reshapeOp, fusedOp->getResults());
1202  return success();
1203  }
1204 };
1205 
1206 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1207 /// by expanding the dimensionality of the loop in the producer op.
1208 struct FoldReshapeWithGenericOpByExpansion
1209  : public OpRewritePattern<tensor::ExpandShapeOp> {
1210 
1211  FoldReshapeWithGenericOpByExpansion(
1212  MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1213  PatternBenefit benefit = 1)
1214  : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1215  controlFoldingReshapes(std::move(foldReshapes)) {}
1216 
1217  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1218  PatternRewriter &rewriter) const override {
1219  // Fold only if all constraints of fusing with reshape by expansion are met.
1220  GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1221  if (!producer || producer.getNumOutputs() != 1 ||
1222        !isFusableWithReshapeByDimExpansion(producer,
1223                                            producer.getOutputOperand(0)) ||
1224  !controlFoldingReshapes(producer->getResult(0),
1225  reshapeOp->getOpOperand(0)))
1226  return failure();
1227  Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1228  producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1229  if (!replacementValues)
1230  return failure();
1231  rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1232  return success();
1233  }
1234 
1235 private:
1236  ControlElementwiseOpsFusionFn controlFoldingReshapes;
1237 };
1238 
1239 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1240 /// handle cases where the constant is not single-valued.
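///
/// For example (an illustrative sketch):
///   %cst = arith.constant dense<4.0> : tensor<8xf32>
///   %0 = linalg.generic ... ins(%arg0, %cst : tensor<8xf32>, tensor<8xf32>) ...
/// is rewritten so that the splat operand is dropped from the generic op and a
/// scalar `arith.constant 4.000000e+00 : f32` is used directly inside its
/// region.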
1241 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1242 public:
1243  FoldScalarOrSplatConstant(MLIRContext *context,
1244                            const ControlElementwiseOpsFusionFn &fun,
1245                            PatternBenefit benefit = 1)
1246  : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1247 
1248  LogicalResult matchAndRewrite(GenericOp genericOp,
1249  PatternRewriter &rewriter) const override {
1250  if (!genericOp.hasTensorSemantics())
1251  return failure();
1252  for (OpOperand *opOperand : genericOp.getInputOperands()) {
1253  Operation *def = opOperand->get().getDefiningOp();
1254  Attribute constantAttr;
1255  auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1256  {
1257  DenseElementsAttr splatAttr;
1258  if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1259  splatAttr.isSplat() &&
1260  splatAttr.getType().getElementType().isIntOrFloat()) {
1261  constantAttr = splatAttr.getSplatValue<Attribute>();
1262  return true;
1263  }
1264  }
1265  {
1266  IntegerAttr intAttr;
1267  if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1268  constantAttr = intAttr;
1269  return true;
1270  }
1271  }
1272  {
1273  FloatAttr floatAttr;
1274  if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1275  constantAttr = floatAttr;
1276  return true;
1277  }
1278  }
1279  return false;
1280  };
1281 
1282  auto resultValue = opOperand->get().dyn_cast<OpResult>();
1283  if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
1284  !controlFn(resultValue, *opOperand))
1285  continue;
1286 
1287  // The operands and the indexing_maps of the fused operation are the same as
1288  // the operands and indexing_maps of the generic operations with the
1289  // values at the constant index dropped.
1290  SmallVector<AffineMap> fusedIndexMaps;
1291  SmallVector<Value> fusedOperands;
1292  SmallVector<Location> fusedLocs{genericOp.getLoc()};
1293  fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1294  fusedOperands.reserve(genericOp.getNumInputs());
1295  fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1296  for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1297  if (inputOperand == opOperand)
1298  continue;
1299  Value inputValue = inputOperand->get();
1300  fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1301  fusedOperands.push_back(inputValue);
1302  fusedLocs.push_back(inputValue.getLoc());
1303  }
1304  for (OpOperand *outputOperand : genericOp.getOutputOperands())
1305  fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1306 
1307  // Check if the operation shapes to loops map is computable.
1308  if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1309  return rewriter.notifyMatchFailure(
1310  genericOp, "fused op loop bound computation failed");
1311  }
1312 
1313  // Create a constant scalar value from the splat constant.
1314  Value scalarConstant = rewriter.create<arith::ConstantOp>(
1315  def->getLoc(), constantAttr, constantAttr.getType());
1316 
1317  SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1318  auto fusedOp = rewriter.create<GenericOp>(
1319  rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1320  /*inputs=*/fusedOperands,
1321  /*outputs=*/outputOperands,
1322  rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1323  genericOp.iterator_types(),
1324  /*doc=*/nullptr,
1325  /*library_call=*/nullptr);
1326 
1327  // Map the block argument corresponding to the replaced operand to the
1328  // scalar constant.
1329  Region &region = genericOp->getRegion(0);
1330  Block &entryBlock = *region.begin();
1331  BlockAndValueMapping mapping;
1332  mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1333  scalarConstant);
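 // Since that block argument is pre-mapped, the region cloning below does
 // not recreate it, so the cloned body is left with one argument per
 // remaining operand of the fused op.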
1334  Region &fusedRegion = fusedOp->getRegion(0);
1335  rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1336  mapping);
1337  rewriter.replaceOp(genericOp, fusedOp->getResults());
1338  return success();
1339  }
1340  return failure();
1341  }
1342 
1343 private:
1344  ControlElementwiseOpsFusionFn controlFn;
1345 };
1346 
1347 /// Base class for constant folding linalg.generic ops with N inputs, 1 output,
1348 /// and permutation indexing maps.
1349 ///
1350 /// `ConcreteType` should provide methods with signatures
1351 ///
1352 /// ```c++
1353 /// bool matchIndexingMaps(GenericOp genericOp) const;
1354 /// RegionComputationFn getRegionComputeFn(GenericOp) const;
1355 /// ```
1356 ///
1357 /// The latter inspects the region and returns the computation inside as a
1358 /// functor. The functor will be invoked with constant elements for all inputs
1359 /// and should return the corresponding computed constant element for the output.
1360 template <typename ConcreteType>
1361 class FoldConstantBase : public OpRewritePattern<GenericOp> {
1362 public:
1363  struct APIntOrFloat {
1364  Optional<APInt> apInt;
1365  Optional<APFloat> apFloat;
1366  };
1367  struct APIntOrFloatArray {
1368  SmallVector<APInt> apInts;
1369  SmallVector<APFloat> apFloats;
1370  };
1371  using RegionComputationFn =
1372  std::function<APIntOrFloat(const APIntOrFloatArray &)>;
1373 
1374  FoldConstantBase(MLIRContext *context,
1375  const ControlElementwiseOpsFusionFn &controlFn,
1376  PatternBenefit benefit = 1)
1377  : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
1378 
1379  LogicalResult matchAndRewrite(GenericOp genericOp,
1380  PatternRewriter &rewriter) const override {
1381  if (genericOp.hasBufferSemantics())
1382  return failure();
1383 
1384  // Only support ops generating one output for now.
1385  if (genericOp.getNumOutputs() != 1)
1386  return failure();
1387 
1388  auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
1389  // Require the output type to be static given that we are generating constants.
1390  if (!outputType || !outputType.hasStaticShape())
1391  return failure();
1392 
1393  if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
1394  return operand->get().getType().isa<ShapedType>();
1395  }))
1396  return failure();
1397 
1398  // Make sure all element types are the same.
1399  auto getOperandElementType = [](OpOperand *operand) {
1400  return operand->get().getType().cast<ShapedType>().getElementType();
1401  };
1402  if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
1403  getOperandElementType)))
1404  return failure();
1405 
1406  // We can only handle the case where we have int/float elements.
1407  auto elementType = outputType.getElementType();
1408  if (!elementType.isIntOrFloat())
1409  return failure();
1410 
1411  // Require all indexing maps to be permutations for now. This is common and
1412  // it simplifies input/output access greatly: we can do the data shuffling
1413  // entirely in the compiler, without needing to turn all indices into
1414  // Values, apply affine maps to them, and then match the constants back
1415  // again.
1416  if (!llvm::all_of(genericOp.getIndexingMaps(),
1417  [](AffineMap map) { return map.isPermutation(); }))
1418  return failure();
1419 
1420  for (OpOperand *operand : genericOp.getOutputOperands()) {
1421  if (genericOp.payloadUsesValueFromOperand(operand))
1422  return failure();
1423  }
1424 
1425  // Further check the indexing maps are okay for the ConcreteType.
1426  if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
1427  return failure();
1428 
1429  // Defer to the concrete type to check the region and discover the
1430  // computation inside.
1431  RegionComputationFn computeFn =
1432  static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
1433  if (!computeFn)
1434  return failure();
1435 
1436  // All inputs should be constants.
1437  int numInputs = genericOp.getNumInputs();
1438  SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
1439  for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
1440  if (!matchPattern(operand.value()->get(),
1441  m_Constant(&inputValues[operand.index()])))
1442  return failure();
1443  }
1444 
1445  // Identified this as a potential candidate for folding. Now check the
1446  // policy to see whether we are allowed to proceed.
1447  for (int i = 0; i < numInputs; ++i) {
1448  OpOperand *consumer = genericOp.getInputOperand(i);
1449  OpResult producer = consumer->get().cast<OpResult>();
1450  if (!controlFn(producer, *consumer))
1451  return failure();
1452  }
1453 
1454  auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
1455  SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
1456  int64_t numElements = outputType.getNumElements();
1457 
1458  // Use APInt/APFloat instead of Attribute here for constructing the output.
1459  // This helps to avoid blowing up compiler memory usage: Attributes would
1460  // unify the following cases, but they have the same lifetime as the MLIRContext.
1461  SmallVector<APInt> intOutputValues;
1462  SmallVector<APFloat> fpOutputValues;
1463  if (elementType.template isa<FloatType>())
1464  fpOutputValues.resize(numElements, APFloat(0.f));
1465  else
1466  intOutputValues.resize(numElements);
1467 
1468  // Return the constant dim positions from the given permutation map.
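 // For example, the permutation map (d0, d1) -> (d1, d0) yields the dim
 // positions {1, 0}.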
1469  auto getDimPositions = [](AffineMap map) {
1470  SmallVector<unsigned> dims;
1471  dims.reserve(map.getNumResults());
1472  for (AffineExpr result : map.getResults()) {
1473  dims.push_back(result.cast<AffineDimExpr>().getPosition());
1474  }
1475  return dims;
1476  };
1477 
1478  SmallVector<SmallVector<unsigned>> inputDims;
1479  for (int i = 0; i < numInputs; ++i)
1480  inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
1481  auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
1482  auto outputShape = outputType.getShape();
1483 
1484  // Allocate small vectors for index delinearization. Initial values do not
1485  // matter here as they will be overwritten later.
1486  SmallVector<uint64_t> indices(loopBounds.size(), 0);
1487  SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
1488  SmallVector<SmallVector<uint64_t>> srcIndices(
1489  numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
1490  SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
1491  uint64_t dstLinearIndex = 0;
1492 
1493  // Allocate space for the compute function inputs. Initial values do not matter
1494  // here as they will be overwritten later.
1495  APIntOrFloatArray computeFnInputs;
1496 
1497  auto inputShapes = llvm::to_vector<4>(
1498  llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
1499  return operand->get().getType().cast<ShapedType>().getShape();
1500  }));
1501 
1502  // Given a `linearIndex`, remap it to the linear indices used to access the
1503  // linalg op inputs/outputs. This mutates `indices`, `srcIndices`,
1504  // `dstIndices`, `srcLinearIndices`, and `dstLinearIndex` in place.
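 // As a concrete example, assume a 2x3 -> 3x2 transpose with input map
 // (d0, d1) -> (d0, d1) and output map (d0, d1) -> (d1, d0), so that
 // loopBounds = [2, 3], inputShapes = [[2, 3]] and outputShape = [3, 2].
 // linearIndex = 1 delinearizes to indices = [0, 1], giving
 // srcIndices[0] = [0, 1] and dstIndices = [1, 0]; relinearizing yields
 // srcLinearIndices[0] = 0 * 3 + 1 = 1 and dstLinearIndex = 1 * 2 + 0 = 2,
 // i.e. input element (0, 1) is written to output element (1, 0).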
1505  auto computeRemappedLinearIndex = [&](int linearIndex) {
1506  int totalCount = linearIndex;
1507  for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1508  indices[dim] = totalCount % loopBounds[dim];
1509  totalCount /= loopBounds[dim];
1510  }
1511 
1512  for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1513  for (int i = 0; i < numInputs; ++i)
1514  srcIndices[i][dim] = indices[inputDims[i][dim]];
1515  dstIndices[dim] = indices[outputDims[dim]];
1516  }
1517 
1518  dstLinearIndex = dstIndices.front();
1519  for (int i = 0; i < numInputs; ++i)
1520  srcLinearIndices[i] = srcIndices[i].front();
1521 
1522  for (int dim = 1; dim < outputType.getRank(); ++dim) {
1523  dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
1524  for (int i = 0; i < numInputs; ++i)
1525  srcLinearIndices[i] =
1526  srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
1527  }
1528  };
1529 
1530  bool isFloat = elementType.isa<FloatType>();
1531  if (isFloat) {
1532  SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
1533  for (int i = 0; i < numInputs; ++i)
1534  inFpRanges.push_back(inputValues[i].getValues<APFloat>());
1535 
1536  computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
1537 
1538  // Transpose the input constant. Because we don't know its rank in
1539  // advance, we need to loop over the range [0, element count) and
1540  // delinearize the index.
1541  for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
1542  computeRemappedLinearIndex(linearIndex);
1543 
1544  // Collect constant elements for all inputs at this loop iteration.
1545  for (int i = 0; i < numInputs; ++i)
1546  computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
1547 
1548  // Invoke the computation to get the corresponding constant output
1549  // element.
1550  fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
1551  }
1552  } else {
1553  SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
1554  for (int i = 0; i < numInputs; ++i)
1555  inIntRanges.push_back(inputValues[i].getValues<APInt>());
1556 
1557  computeFnInputs.apInts.resize(numInputs);
1558 
1559  // Transpose the input constant. Because we don't know its rank in
1560  // advance, we need to loop over the range [0, element count) and
1561  // delinearize the index.
1562  for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
1563  computeRemappedLinearIndex(linearIndex);
1564 
1565  // Collect constant elements for all inputs at this loop iteration.
1566  for (int i = 0; i < numInputs; ++i)
1567  computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
1568 
1569  // Invoke the computation to get the corresponding constant output
1570  // element.
1571  intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
1572  }
1573  }
1574 
1575  DenseElementsAttr outputAttr =
1576  isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
1577  : DenseElementsAttr::get(outputType, intOutputValues);
1578 
1579  rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
1580  return success();
1581  }
1582 
1583 private:
1585 };
1586 
1587 // Folds linalg.generic ops that are actually transposes on constant values.
1588 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
1589  using FoldConstantBase::FoldConstantBase;
1590 
1591  bool matchIndexingMaps(GenericOp genericOp) const {
1592  // We should have one input and one output.
1593  return genericOp.getIndexingMaps().size() == 2;
1594  }
1595 
1596  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
1597  // Make sure the region only contains a yield op.
1598  Block &body = genericOp.region().front();
1599  if (!llvm::hasSingleElement(body))
1600  return nullptr;
1601  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1602  if (!yieldOp)
1603  return nullptr;
1604 
1605  // The yield op should return the block argument corresponding to the input.
1606  for (Value yieldVal : yieldOp.values()) {
1607  auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
1608  if (!yieldArg || yieldArg.getOwner() != &body)
1609  return nullptr;
1610  if (yieldArg.getArgNumber() != 0)
1611  return nullptr;
1612  }
1613 
1614  // No computation; just return the original value.
1615  return [](const APIntOrFloatArray &inputs) {
1616  if (inputs.apFloats.empty())
1617  return APIntOrFloat{inputs.apInts.front(), llvm::None};
1618  return APIntOrFloat{llvm::None, inputs.apFloats.front()};
1619  };
1620  }
1621 
1623 };
1624 
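A second folder built on `FoldConstantBase` only needs to provide the same two hooks. The sketch below is illustrative only (the `FoldConstantNegF` name is hypothetical and the region check is elided); it shows how `getRegionComputeFn` returns a functor that maps constant input elements to a constant output element:

```c++
// Hypothetical sketch, not part of this file: fold a generic op whose body is
// a single arith.negf on a constant input.
struct FoldConstantNegF : public FoldConstantBase<FoldConstantNegF> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // One input and one output, as for FoldConstantTranspose above.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // A real implementation would verify here that the body is exactly
    // `%r = arith.negf %arg0` followed by `linalg.yield %r`.
    return [](const APIntOrFloatArray &inputs) {
      APFloat v = inputs.apFloats.front();
      v.changeSign();
      return APIntOrFloat{llvm::None, v};
    };
  }
};
```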
1625 } // namespace
1626 
1627 static Optional<SmallVector<Value>>
1628 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
1629  GenericOp producer,
1630  const ControlElementwiseOpsFusionFn &controlFn) {
1631  if (producer->getNumResults() != 1)
1632  return llvm::None;
1633 
1634  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
1635  rewriter);
1636 }
1637 
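/// Default control: a reshape that only introduces or removes unit dimensions
/// (e.g. tensor<1x4x1xf32> <-> tensor<4xf32>) is not folded; all other
/// reshapes are.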
1638 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
1639  OpOperand &consumer) {
1640  if (auto producerCollapseOp =
1641  dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
1642  return !isUnitDimExpansionOnly(producerCollapseOp);
1643  }
1644  if (auto consumerExpandOp =
1645  dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
1646  return !isUnitDimExpansionOnly(consumerExpandOp);
1647  }
1648  return true;
1649 }
1650 
1651 namespace {
1652 /// Pattern to fuse a generic op with the producer of one of its operands.
1653 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
1654 public:
1655  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
1656  PatternBenefit benefit = 1)
1657  : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1658 
1659  LogicalResult matchAndRewrite(GenericOp genericOp,
1660  PatternRewriter &rewriter) const override {
1661  // Find the first operand that is defined by another generic op on tensors.
1662  for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
1663  auto producer =
1664  dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
1665  if (!producer || !producer.hasTensorSemantics())
1666  continue;
1667  Optional<SmallVector<Value>> fusedOpResults =
1668  fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
1669  if (fusedOpResults) {
1670  rewriter.replaceOp(genericOp, *fusedOpResults);
1671  return success();
1672  }
1673  }
1674  return failure();
1675  }
1676 
1677 private:
1678  ControlElementwiseOpsFusionFn controlFn;
1679 };
1680 
1681 /// Pass that fuses generic ops on tensors. Used only for testing.
1682 struct LinalgElementwiseOpFusionPass
1683  : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1684  void runOnOperation() override {
1685  Operation *op = getOperation();
1686  RewritePatternSet patterns(op->getContext());
1687  ControlElementwiseOpsFusionFn allowFoldingFn =
1688  [](const OpResult &producer, const OpOperand &consumer) {
1689  return true;
1690  };
1691  populateElementwiseOpsFusionPatterns(
1692  patterns,
1693  LinalgElementwiseFusionOptions().setControlFoldingReshapes(
1694  allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
1695 
1696  // Use TopDownTraversal for compile-time reasons.
1697  GreedyRewriteConfig grc;
1698  grc.useTopDownTraversal = true;
1699  (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1700  grc);
1701  }
1702 };
1703 
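Downstream users typically call `populateElementwiseOpsFusionPatterns` directly rather than running the test pass above. A minimal sketch, assuming a caller that only wants to fuse producers with a single use (the `populateSingleUseElementwiseFusion` and `fuseSingleUseOnly` names are illustrative, not part of this file):

```c++
// Illustrative helper: restrict elementwise fusion to single-use producers by
// supplying a custom ControlElementwiseOpsFusionFn.
static void populateSingleUseElementwiseFusion(RewritePatternSet &patterns) {
  ControlElementwiseOpsFusionFn fuseSingleUseOnly =
      [](const OpResult &producer, OpOperand &consumer) {
        // Fuse only when the producer result is consumed exactly once.
        return producer.hasOneUse();
      };
  LinalgElementwiseFusionOptions options;
  options.controlElementwiseOpsFusionFn = fuseSingleUseOnly;
  populateElementwiseOpsFusionPatterns(patterns, options);
}
```

The resulting patterns can then be driven with `applyPatternsAndFoldGreedily`, as the pass above does.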
1704 /// Pass to test folding of reshape ops with generic ops by linearization.
1705 struct FoldReshapeOpsByLinearizationPass
1706  : public LinalgFoldReshapeOpsByLinearizationBase<
1707  FoldReshapeOpsByLinearizationPass> {
1708  void runOnOperation() override {
1709  Operation *op = getOperation();
1710  RewritePatternSet patterns(op->getContext());
1711  populateFoldReshapeOpsByLinearizationPatterns(patterns);
1712  if (allowFoldingUnitDimReshapes) {
1713  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
1714  }
1715  (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1716  }
1717 };
1718 
1719 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
1720 /// the value of the `outs` operand is not used within the op. This is only
1721 /// implemented for `linalg.generic` operations for now, but should hold for all
1722 /// linalg structured ops.
1723 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1724  using OpRewritePattern<GenericOp>::OpRewritePattern;
1725 
1726  LogicalResult matchAndRewrite(GenericOp op,
1727  PatternRewriter &rewriter) const override {
1728  rewriter.startRootUpdate(op);
1729  bool modifiedOutput = false;
1730  Location loc = op.getLoc();
1731  for (OpOperand *opOperand : op.getOutputOperands()) {
1732  if (!op.payloadUsesValueFromOperand(opOperand)) {
1733  Value operandVal = opOperand->get();
1734  auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1735  if (!operandType)
1736  continue;
1737 
1738  // If outs is already an `init_tensor` operation, nothing to do.
1739  auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1740  if (definingOp)
1741  continue;
1742  modifiedOutput = true;
1743  SmallVector<Value> dynamicDims;
1744  for (const auto &dim : llvm::enumerate(operandType.getShape())) {
1745  if (dim.value() != ShapedType::kDynamicSize)
1746  continue;
1747  dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1748  loc, operandVal, dim.index()));
1749  }
1750  Value initTensor = rewriter.create<InitTensorOp>(
1751  loc, dynamicDims, operandType.getShape(),
1752  operandType.getElementType());
1753  op->setOperand(opOperand->getOperandNumber(), initTensor);
1754  }
1755  }
1756  if (!modifiedOutput) {
1757  rewriter.cancelRootUpdate(op);
1758  return failure();
1759  }
1760  rewriter.finalizeRootUpdate(op);
1761  return success();
1762  }
1763 };
1764 
1765 } // namespace
1766 
1767 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
1768  RewritePatternSet &patterns) {
1769  patterns
1770  .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1771  FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
1772  FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1773  FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
1774  patterns.getContext());
1775 }
1776 
1777 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
1778  RewritePatternSet &patterns) {
1779  patterns
1780  .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1781  FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
1782  FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1783  FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
1784  patterns.getContext());
1785 }
1786 
1787 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1788  RewritePatternSet &patterns,
1789  const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
1790  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1791  controlFoldingReshapes);
1792  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1793  controlFoldingReshapes);
1794 }
1795 
1796 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1797  RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
1798  auto *context = patterns.getContext();
1799  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
1800  FoldConstantTranspose>(context,
1801  options.controlElementwiseOpsFusionFn);
1802  patterns.add<RemoveOutsDependency>(context);
1803  populateFoldReshapeOpsByExpansionPatterns(patterns,
1804  options.controlFoldingReshapesFn);
1805  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1806  GenericOp::getCanonicalizationPatterns(patterns, context);
1807  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1808  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1809  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1810  patterns);
1811 }
1812 
1813 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
1814  auto *context = patterns.getContext();
1815  patterns.add<PushExpandingReshape>(context);
1816 }
1817 
1818 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1819  return std::make_unique<LinalgElementwiseOpFusionPass>();
1820 }
1821 
1822 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1823  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1824 }