MLIR 23.0.0git
Split.cpp
Go to the documentation of this file.
1//===- Split.cpp - Structured op splitting --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/AffineExpr.h"
13#include "mlir/IR/Attributes.h"
17#include "llvm/ADT/SmallVectorExtras.h"
18
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
21
22using namespace mlir;
23using namespace mlir::linalg;
24
25/// Creates a part of the given `op` split along the iteration space `dimension`
26/// with the given `size` and an optional `offset` (default 0). Makes slices
27/// of operands, using the input operands of the original op and the output
28/// operands provided as `resultOperands`. Expects `offsets` and `sizes` to
29/// define the shape of the iteration space of the original op. Returns the
30/// split-out op as well as the output operand values updated with the partial
31/// results produced by this op through `results`.
32static TilingInterface
33createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
35 ValueRange resultOperands, unsigned dimension,
36 OpFoldResult size, OpFoldResult offset,
37 SmallVectorImpl<Value> &results) {
38 // Iteration space of the current part.
39 SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes);
40 SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets);
41 sizesCopy[dimension] = size;
42 offsetsCopy[dimension] = offset;
43
44 // Create the part as if it were a single tile.
45 FailureOr<TilingResult> tilingResult =
46 op.getTiledImplementation(b, offsetsCopy, sizesCopy);
47
48 // Insert the results back and populate the `results` list.
49 for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
50 SmallVector<OpFoldResult> resultOffsets, resultSizes;
51 if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
52 resultOffsets, resultSizes)))
53 return nullptr;
54 SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
55 b.getIndexAttr(1));
56 Value inserted = tensor::InsertSliceOp::create(
57 b, loc, result, resultOperands[index], resultOffsets, resultSizes,
58 resultStrides);
59 results.push_back(inserted);
60 }
61 // TODO: this part can be generalized maybe to not expect a single op.
62 assert(tilingResult->tiledOps.size() == 1 &&
63 "expected split part to return a single tiled operation");
64 return cast<TilingInterface>(tilingResult->tiledOps[0]);
65}
66
67std::pair<TilingInterface, TilingInterface>
68linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
69 OpFoldResult splitPoint) {
70 // Compute the iteration space.
71 SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
72
73 // Bail out on dimension overflow.
74 if (dimension >= iterationSpace.size())
75 return std::make_pair(op, TilingInterface());
76
77 SmallVector<OpFoldResult> offsets = llvm::map_to_vector(
78 iterationSpace, [](const Range &range) { return range.offset; });
79 SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
80 iterationSpace, [](const Range &range) { return range.size; });
81
82 // Adjust the split point so that it doesn't overflow the size.
83 AffineExpr d0, d1, d2;
84 bindDims(rewriter.getContext(), d0, d1, d2);
86 rewriter, op.getLoc(),
88 rewriter.getContext())
89 .front(),
90 {splitPoint, offsets[dimension], sizes[dimension]});
91
92 // Compute the size of the second part. Return early if the second part would
93 // have an empty iteration space.
95 rewriter, op.getLoc(), d0 + d1 - d2,
96 {iterationSpace[dimension].offset, iterationSpace[dimension].size,
97 minSplitPoint});
98 if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
99 if (cast<IntegerAttr>(attr).getValue().isZero())
100 return {op, TilingInterface()};
101 }
102
103 // Compute destination tensors.
104 SmallVector<Value> destinationTensors;
105 LogicalResult destStatus = tensor::getOrCreateDestinations(
106 rewriter, op.getLoc(), op, destinationTensors);
107 (void)destStatus;
108 assert(succeeded(destStatus) && "failed to get destination tensors");
109
110 // Create the first part.
111 SmallVector<Value> firstResults;
112 TilingInterface firstPart = createSplitPart(
113 rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension,
114 minSplitPoint, iterationSpace[dimension].offset, firstResults);
115
116 // Need to pretend that the original op now takes as operands firstResults,
117 // otherwise tiling interface implementation will take the wrong value to
118 // produce data tiles.
119 rewriter.modifyOpInPlace(op, [&]() {
120 unsigned numTotalOperands = op->getNumOperands();
121 unsigned numOutputOperands = firstResults.size();
122 op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
123 firstResults);
124 });
125
126 // Create the second part.
128 rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint});
129 SmallVector<Value> secondResults;
130 TilingInterface secondPart =
131 createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
132 dimension, remainingSize, totalOffset, secondResults);
133
134 // Propagate any errors in part creation.
135 if (!firstPart || !secondPart)
136 return {TilingInterface(), TilingInterface()};
137
138 // Replace the original op with the results of the two newly created ops.
139 rewriter.replaceOp(op, secondResults);
140 return {firstPart, secondPart};
141}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
…if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method…
static TilingInterface createSplitPart(RewriterBase &b, Location loc, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange resultOperands, unsigned dimension, OpFoldResult size, OpFoldResult offset, SmallVectorImpl< Value > &results)
Creates a part of the given op split along the iteration space dimension with the given size and an o...
Definition Split.cpp:33
Base type for affine expression.
Definition AffineExpr.h:68
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class represents a single result from folding an operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition Split.cpp:68
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size
OpFoldResult offset