20 #include "llvm/Support/Debug.h"
22 #define DEBUG_TYPE "tensor-swap-slices"
28 auto producerOp = dyn_cast<TilingInterface>(producer.
getOwner());
33 if (!llvm::all_of(sliceOp.getMixedStrides(),
isOneInteger))
36 FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
38 sliceOp.getMixedSizes());
39 if (failed(tiledResult))
48 if (sliceOps.empty()) {
50 { llvm::dbgs() <<
"expected candidate slices list to be non-empty"; });
53 if (sliceOps.size() != consumerOperands.size()) {
56 <<
"expected as many operands as the number of slices passed";
61 dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
64 for (
auto opOperand : consumerOperands.drop_front()) {
65 if (opOperand->getOwner() != consumerOp) {
68 <<
"expected all consumer operands to be from the same operation";
74 auto consumerOperandNums = llvm::map_to_vector(
75 consumerOperands, [](
OpOperand *opOperand) ->
unsigned {
80 for (
auto sliceOp : sliceOps) {
83 if (!llvm::all_of(sliceOp.getMixedStrides(),
isOneInteger))
88 allOffsets.emplace_back(std::move(offsets));
89 allSizes.emplace_back(std::move(sizes));
91 FailureOr<TilingResult> tiledResult =
92 consumerOp.getTiledImplementationFromOperandTiles(
93 builder, consumerOperandNums, allOffsets, allSizes);
94 if (failed(tiledResult))
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInterface.
Include the generated interface declarations.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.