MLIR  22.0.0git
SwapExtractSliceWithProducerPatterns.cpp
Go to the documentation of this file.
1 //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
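// Swap a `tensor.extract_slice` with the producer of its source, or a set of
// `tensor.insert_slice` ops with their common consumer, when that producer or
// consumer implements the `TilingInterface`. When used in conjunction with
// tiling, this effectively tiles + fuses the producer (or consumer) with the
// op on the other side of the slice.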
12 //
13 //===----------------------------------------------------------------------===//
14 
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "tensor-swap-slices"
22 
23 using namespace mlir;
24 
26  OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
27  auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
28  if (!producerOp)
29  return failure();
30 
31  // `TilingInterface` currently only supports strides being 1.
32  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
33  return failure();
34 
35  FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
36  builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
37  sliceOp.getMixedSizes());
38  if (failed(tiledResult))
39  return failure();
40 
41  // For cases where the slice was rank-reducing, create a rank-reducing slice
42  // to get the same type back.
43  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
44  if (droppedDims.any()) {
45  assert(tiledResult->tiledValues.size() == 1 &&
46  "expected only a single tiled result value to replace the extract "
47  "slice");
48  SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(),
49  builder.getIndexAttr(0));
50  SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(),
51  builder.getIndexAttr(1));
52  auto newSliceOp = tensor::ExtractSliceOp::create(
53  builder, sliceOp.getLoc(), sliceOp.getType(),
54  tiledResult->tiledValues[0], offsets, sliceOp.getMixedSizes(), strides);
55  tiledResult->tiledValues[0] = newSliceOp;
56  }
57 
58  return *tiledResult;
59 }
60 
62  OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
63  ArrayRef<OpOperand *> consumerOperands) {
64  if (sliceOps.empty()) {
65  LLVM_DEBUG(
66  { llvm::dbgs() << "expected candidate slices list to be non-empty"; });
67  return failure();
68  }
69  if (sliceOps.size() != consumerOperands.size()) {
70  LLVM_DEBUG({
71  llvm::dbgs()
72  << "expected as many operands as the number of slices passed";
73  });
74  return failure();
75  }
76  auto consumerOp =
77  dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
78  if (!consumerOp)
79  return failure();
80  for (auto opOperand : consumerOperands.drop_front()) {
81  if (opOperand->getOwner() != consumerOp) {
82  LLVM_DEBUG({
83  llvm::dbgs()
84  << "expected all consumer operands to be from the same operation";
85  });
86  return failure();
87  }
88  }
89 
90  auto consumerOperandNums = llvm::map_to_vector(
91  consumerOperands, [](OpOperand *opOperand) -> unsigned {
92  return opOperand->getOperandNumber();
93  });
96  for (auto sliceOp : sliceOps) {
97 
98  // `TilingInterface` currently only supports strides being 1.
99  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
100  return failure();
101 
102  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
103  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
104  allOffsets.emplace_back(std::move(offsets));
105  allSizes.emplace_back(std::move(sizes));
106  }
107  FailureOr<TilingResult> tiledResult =
108  consumerOp.getTiledImplementationFromOperandTiles(
109  builder, consumerOperandNums, allOffsets, allSizes);
110  if (failed(tiledResult))
111  return failure();
112 
113  return *tiledResult;
114 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:456
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInterface.
Include the generated interface declarations.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.