MLIR  20.0.0git
Split.cpp
Go to the documentation of this file.
1 //===- Split.cpp - Structured op splitting --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 /// Creates a part of the given `op` split along the iteration space `dimension`
25 /// with the given `size` and an optional `offset` (default 0). Makes slices
26 /// of operands, using the input operands of the original op and the output
27 /// operands provided as `resultOperands`. Expects `offsets` and `sizes` to
28 /// define the shape of the iteration space of the original op. Returns the
29 /// split-out op as well as the output operand values updated with the partial
30 /// results produced by this op through `results`.
31 static TilingInterface
32 createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
34  ValueRange resultOperands, unsigned dimension,
35  OpFoldResult size, OpFoldResult offset,
36  SmallVectorImpl<Value> &results) {
37  // Iteration space of the current part.
38  SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes);
39  SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets);
40  sizesCopy[dimension] = size;
41  offsetsCopy[dimension] = offset;
42 
43  // Create the part as it it were a single tile.
44  FailureOr<TilingResult> tilingResult =
45  op.getTiledImplementation(b, offsetsCopy, sizesCopy);
46 
47  // Insert the results back and populate the `results` list.
48  for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
49  SmallVector<OpFoldResult> resultOffsets, resultSizes;
50  if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
51  resultOffsets, resultSizes)))
52  return nullptr;
53  SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
54  b.getIndexAttr(1));
55  Value inserted = b.create<tensor::InsertSliceOp>(
56  loc, result, resultOperands[index], resultOffsets, resultSizes,
57  resultStrides);
58  results.push_back(inserted);
59  }
60  // TODO: this part can be generalized maybe to not expect a single op.
61  assert(tilingResult->tiledOps.size() == 1 &&
62  "expected split part to return a single tiled operation");
63  return cast<TilingInterface>(tilingResult->tiledOps[0]);
64 }
65 
66 std::pair<TilingInterface, TilingInterface>
67 linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
68  OpFoldResult splitPoint) {
69  // Compute the iteration space.
70  SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
71 
72  // Bail out on dimension overflow.
73  if (dimension >= iterationSpace.size())
74  return std::make_pair(op, TilingInterface());
75 
76  SmallVector<OpFoldResult> offsets = llvm::to_vector(llvm::map_range(
77  iterationSpace, [](const Range &range) { return range.offset; }));
78  SmallVector<OpFoldResult> sizes = llvm::to_vector(llvm::map_range(
79  iterationSpace, [](const Range &range) { return range.size; }));
80 
81  // Adjust the split point so that it doesn't overflow the size.
82  AffineExpr d0, d1, d2;
83  bindDims(rewriter.getContext(), d0, d1, d2);
85  rewriter, op.getLoc(),
87  rewriter.getContext())
88  .front(),
89  {splitPoint, offsets[dimension], sizes[dimension]});
90 
91  // Compute the size of the second part. Return early if the second part would
92  // have an empty iteration space.
94  rewriter, op.getLoc(), d0 + d1 - d2,
95  {iterationSpace[dimension].offset, iterationSpace[dimension].size,
96  minSplitPoint});
97  if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
98  if (cast<IntegerAttr>(attr).getValue().isZero())
99  return {op, TilingInterface()};
100  }
101 
102  // Compute destination tensors.
103  SmallVector<Value> destinationTensors;
104  LogicalResult destStatus = tensor::getOrCreateDestinations(
105  rewriter, op.getLoc(), op, destinationTensors);
106  (void)destStatus;
107  assert(succeeded(destStatus) && "failed to get destination tensors");
108 
109  // Create the first part.
110  SmallVector<Value> firstResults;
111  TilingInterface firstPart = createSplitPart(
112  rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension,
113  minSplitPoint, iterationSpace[dimension].offset, firstResults);
114 
115  // Need to pretend that the original op now takes as operands firstResults,
116  // otherwise tiling interface implementation will take the wrong value to
117  // produce data tiles.
118  rewriter.modifyOpInPlace(op, [&]() {
119  unsigned numTotalOperands = op->getNumOperands();
120  unsigned numOutputOperands = firstResults.size();
121  op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
122  firstResults);
123  });
124 
125  // Create the second part.
127  rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint});
128  SmallVector<Value> secondResults;
129  TilingInterface secondPart =
130  createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
131  dimension, remainingSize, totalOffset, secondResults);
132 
133  // Propagate any errors in part creation.
134  if (!firstPart || !secondPart)
135  return {TilingInterface(), TilingInterface()};
136 
137  // Replace the original op with the results of the two newly created ops.
138  rewriter.replaceOp(op, secondResults);
139  return {firstPart, secondPart};
140 }
static TilingInterface createSplitPart(RewriterBase &b, Location loc, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange resultOperands, unsigned dimension, OpFoldResult size, OpFoldResult offset, SmallVectorImpl< Value > &results)
Creates a part of the given op split along the iteration space dimension with the given size and an o...
Definition: Split.cpp:32
Base type for affine expression.
Definition: AffineExpr.h:68
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
MLIRContext * getContext() const
Definition: Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1300
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:110
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size
OpFoldResult offset