MLIR  21.0.0git
SwapExtractSliceWithProducerPatterns.cpp
Go to the documentation of this file.
1 //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Swap a `tensor.extract_slice` with the producer of its source if the producer
10 // implements the `TilingInterface`. When used in conjunction with tiling, this
11 // effectively tiles and fuses the producer with its consumer.
12 //
13 //===----------------------------------------------------------------------===//
14 
20 
21 using namespace mlir;
22 
24  OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
25  auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
26  if (!producerOp)
27  return failure();
28 
29  // `TilingInterface` currently only supports strides being 1.
30  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
31  return failure();
32 
33  FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
34  builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
35  sliceOp.getMixedSizes());
36  if (failed(tiledResult))
37  return failure();
38 
39  return *tiledResult;
40 }
41 
43  OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
44  OpOperand &consumer) {
45  auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
46  if (!consumerOp)
47  return failure();
48 
49  // `TilingInterface` currently only supports strides being 1.
50  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
51  return failure();
52 
53  FailureOr<TilingResult> tiledResult =
54  consumerOp.getTiledImplementationFromOperandTile(
55  builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
56  sliceOp.getMixedSizes());
57  if (failed(tiledResult))
58  return failure();
59 
60  return *tiledResult;
61 }
This class helps build Operations.
Definition: Builders.h:204
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:456
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterface.
Include the generated interface declarations.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.