MLIR 22.0.0git
SwapExtractSliceWithProducerPatterns.cpp
Go to the documentation of this file.
1//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Swap a `tensor.extract_slice` with the producer of the source if the producer
10// implements the `TilingInterface`. When used in conjunction with tiling this
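11// effectively tiles + fuses the producer with its consumer.
// The dual swap is also provided: `tensor.insert_slice` ops are swapped with a
// consumer that implements the `TilingInterface`.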
12//
13//===----------------------------------------------------------------------===//
14
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "tensor-swap-slices"
22
23using namespace mlir;
24
26 OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
// Swap `sliceOp` with the op that owns `producer`. Instead of extracting a
// slice from the producer's full result, ask the producer (which must
// implement `TilingInterface`) to compute only the tile described by the
// slice's offsets and sizes.
//
// Parameters:
//   builder  - passed to `generateResultTileValue` and, for rank-reducing
//              slices, used to create an extra `tensor.extract_slice`.
//   sliceOp  - the `tensor.extract_slice` being swapped; every stride must
//              be 1.
//   producer - the result whose tile is generated; its owner must implement
//              `TilingInterface`.
//
// Returns the `TilingResult` from `generateResultTileValue`, with
// `tiledValues[0]` rewritten when `sliceOp` is rank-reducing. Returns failure
// if the owner does not implement `TilingInterface`, if any stride is not 1,
// or if tile generation fails. This body only builds new IR; replacing or
// erasing `sliceOp` is left to the caller.
//
// NOTE(review): nothing here checks that `producer` is actually `sliceOp`'s
// source. The caller appears to be responsible for that pairing.
27 auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
28 if (!producerOp)
29 return failure();
30
31 // `TilingInterface` currently only supports strides being 1.
32 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
33 return failure();
34
// Request the tile of result number `producer.getResultNumber()` at exactly
// the slice's offsets and sizes.
35 FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
36 builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
37 sliceOp.getMixedSizes());
38 if (failed(tiledResult))
39 return failure();
40
41 // For cases where the slice was rank-reducing, create a rank-reducing slice
42 // to get the same type back.
// The tiled value is sliced again with source-rank all-zero offsets, unit
// strides and the original sizes, but typed as `sliceOp.getType()`. The
// replacement value therefore has the same (rank-reduced) type as `sliceOp`.
43 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
44 if (droppedDims.any()) {
// Only `tiledValues[0]` is re-sliced, so more than one tiled value is not
// handled on this path. This is checked only in assert-enabled builds.
45 assert(tiledResult->tiledValues.size() == 1 &&
46 "expected only a single tiled result value to replace the extract "
47 "slice");
48 SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(),
49 builder.getIndexAttr(0));
50 SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(),
51 builder.getIndexAttr(1));
52 auto newSliceOp = tensor::ExtractSliceOp::create(
53 builder, sliceOp.getLoc(), sliceOp.getType(),
54 tiledResult->tiledValues[0], offsets, sliceOp.getMixedSizes(), strides);
55 tiledResult->tiledValues[0] = newSliceOp;
56 }
57
// NOTE(review): `*tiledResult` is copied into the returned `FailureOr`;
// `std::move(*tiledResult)` would likely avoid copying its containers.
58 return *tiledResult;
59}
60
63 ArrayRef<OpOperand *> consumerOperands) {
// Swap the `tensor.insert_slice` ops in `sliceOps` with the op that owns
// `consumerOperands`. That consumer (which must implement `TilingInterface`)
// is asked for a tiled implementation in which operand `consumerOperands[i]`
// is tiled with the offsets and sizes of `sliceOps[i]`.
//
// Parameters:
//   builder          - passed to `getTiledImplementationFromOperandTiles`.
//   sliceOps         - non-empty list of insert_slice ops; every stride of
//                      every slice must be 1.
//   consumerOperands - exactly one operand per entry of `sliceOps`, all owned
//                      by the same operation.
//
// Returns the `TilingResult` produced by the consumer. Returns failure if:
//   - the lists are empty or of different lengths;
//   - the operands have different owners;
//   - the owner does not implement `TilingInterface`;
//   - any stride is not 1;
//   - the consumer cannot be tiled.
// This body only builds new IR; replacing the consumer or erasing the slices
// is left to the caller.
//
// NOTE(review): the LLVM_DEBUG messages below have no trailing newline, so in
// debug output they run into whatever is printed next.
64 if (sliceOps.empty()) {
65 LLVM_DEBUG(
66 { llvm::dbgs() << "expected candidate slices list to be non-empty"; });
67 return failure();
68 }
69 if (sliceOps.size() != consumerOperands.size()) {
70 LLVM_DEBUG({
71 llvm::dbgs()
72 << "expected as many operands as the number of slices passed";
73 });
74 return failure();
75 }
// The first operand's owner is the candidate consumer. Every other operand
// must belong to that same operation.
76 auto consumerOp =
77 dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
78 if (!consumerOp)
79 return failure();
80 for (auto *opOperand : consumerOperands.drop_front()) {
81 if (opOperand->getOwner() != consumerOp) {
82 LLVM_DEBUG({
83 llvm::dbgs()
84 << "expected all consumer operands to be from the same operation";
85 });
86 return failure();
87 }
88 }
89
// Operand indices, in the same order as `consumerOperands`. Slice `i` is
// paired with operand `i` purely by position; nothing here verifies that the
// slice is actually related to that operand.
90 auto consumerOperandNums = llvm::map_to_vector(
91 consumerOperands, [](OpOperand *opOperand) -> unsigned {
92 return opOperand->getOperandNumber();
93 });
// NOTE(review): the declarations of `allOffsets` / `allSizes` are not visible
// in this view. Judging by the `emplace_back` calls below, they are presumably
// vectors of `SmallVector<OpFoldResult>`; confirm against the full file.
// Collect one offsets list and one sizes list per slice. Processing stops at
// the first slice with a non-unit stride.
96 for (auto sliceOp : sliceOps) {
97
98 // `TilingInterface` currently only supports strides being 1.
99 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
100 return failure();
101
102 SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
103 SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
104 allOffsets.emplace_back(std::move(offsets));
105 allSizes.emplace_back(std::move(sizes));
106 }
107 FailureOr<TilingResult> tiledResult =
108 consumerOp.getTiledImplementationFromOperandTiles(
109 builder, consumerOperandNums, allOffsets, allSizes);
110 if (failed(tiledResult))
111 return failure();
112
// NOTE(review): `*tiledResult` is copied into the returned `FailureOr`;
// `std::move(*tiledResult)` would likely avoid copying its containers.
113 return *tiledResult;
114}
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInterface.
Include the generated interface declarations.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.