MLIR  21.0.0git
SwapExtractSliceWithProducerPatterns.cpp
Go to the documentation of this file.
1 //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Swap a `tensor.extract_slice` with the producer of the source if the producer
10 // implements the `TilingInterface`. When used in conjunction with tiling this
11 // effectively tiles + fuses the producer with its consumer.
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "tensor-swap-slices"
23 
24 using namespace mlir;
25 
27  OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
28  auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
29  if (!producerOp)
30  return failure();
31 
32  // `TilingInterface` currently only supports strides being 1.
33  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
34  return failure();
35 
36  FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
37  builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
38  sliceOp.getMixedSizes());
39  if (failed(tiledResult))
40  return failure();
41 
42  return *tiledResult;
43 }
44 
46  OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
47  ArrayRef<OpOperand *> consumerOperands) {
48  if (sliceOps.empty()) {
49  LLVM_DEBUG(
50  { llvm::dbgs() << "expected candidate slices list to be non-empty"; });
51  return failure();
52  }
53  if (sliceOps.size() != consumerOperands.size()) {
54  LLVM_DEBUG({
55  llvm::dbgs()
56  << "expected as many operands as the number of slices passed";
57  });
58  return failure();
59  }
60  auto consumerOp =
61  dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
62  if (!consumerOp)
63  return failure();
64  for (auto opOperand : consumerOperands.drop_front()) {
65  if (opOperand->getOwner() != consumerOp) {
66  LLVM_DEBUG({
67  llvm::dbgs()
68  << "expected all consumer operands to be from the same operation";
69  });
70  return failure();
71  }
72  }
73 
74  auto consumerOperandNums = llvm::map_to_vector(
75  consumerOperands, [](OpOperand *opOperand) -> unsigned {
76  return opOperand->getOperandNumber();
77  });
80  for (auto sliceOp : sliceOps) {
81 
82  // `TilingInterface` currently only supports strides being 1.
83  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
84  return failure();
85 
86  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
87  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
88  allOffsets.emplace_back(std::move(offsets));
89  allSizes.emplace_back(std::move(sizes));
90  }
91  FailureOr<TilingResult> tiledResult =
92  consumerOp.getTiledImplementationFromOperandTiles(
93  builder, consumerOperandNums, allOffsets, allSizes);
94  if (failed(tiledResult))
95  return failure();
96 
97  return *tiledResult;
98 }
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:456
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInterface.
Include the generated interface declarations.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.