24 #include "llvm/Support/Debug.h" 26 #define DEBUG_TYPE "tile-using-interface" 37 &op->getParentOfType<func::FuncOp>().getBody().front());
38 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
50 size_t iterationDomainSize) {
51 SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector);
52 if (filledVector.size() < iterationDomainSize) {
53 auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
54 filledVector.append(range.begin(), range.end());
56 if (filledVector.size() > iterationDomainSize)
57 filledVector.resize(iterationDomainSize);
64 ArrayRef<unsigned> interchange) {
65 assert(interchange.size() == vector.size());
66 return llvm::to_vector(
67 llvm::map_range(interchange, [&](
unsigned val) {
return vector[val]; }));
70 static SmallVector<unsigned>
72 SmallVector<unsigned> inversion(interchange.size());
74 inversion[pos.value()] = pos.index();
80 llvm::SmallDenseSet<unsigned, 4> seenVals;
81 for (
auto val : interchange) {
82 if (seenVals.count(val))
86 return seenVals.size() == interchange.size();
104 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
113 static SmallVector<scf::ForOp>
115 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
116 SmallVector<OpFoldResult> &offsets,
117 SmallVector<OpFoldResult> &sizes) {
118 assert(!loopRanges.empty() &&
"expected at least one loop range");
119 assert(loopRanges.size() == tileSizeVals.size() &&
120 "expected as many tile sizes as loop ranges");
122 SmallVector<scf::ForOp> loops;
123 offsets.resize(loopRanges.size());
124 sizes.resize(loopRanges.size());
142 offsets[loopRange.index()] = offset;
143 sizes[loopRange.index()] = size;
147 auto loop = builder.
create<scf::ForOp>(
148 loc, offset, size, tileSizeVals[loopRange.index()],
ValueRange{},
152 Range{loopRange.value().
offset, loopRange.value().size,
153 tileSizeVals[loopRange.index()]});
154 Value boundedTileSize =
156 ? tileSizeVals[loopRange.index()]
157 : builder.
create<AffineMinOp>(
159 ValueRange{iv, tileSizeVals[loopRange.index()], size});
160 sizes[loopRange.index()] = boundedTileSize;
161 builder.
create<scf::YieldOp>(loc);
163 offsets[loopRange.index()] = loop.getInductionVar();
164 loops.push_back(loop);
174 options(std::move(options)) {}
181 options(std::move(options)) {}
191 op,
"missing tile size computation function");
196 size_t numLoops = iterationDomain.size();
199 op,
"unable to tile op with no iteration domain");
208 if (tileSizeVector.size() < iterationDomain.size()) {
210 tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
221 iterationDomain.size());
223 if (!interchangeVector.empty()) {
226 op,
"invalid intechange vector, not a permutation of the entire " 240 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
242 if (!interchangeVector.empty()) {
249 if (!tilingResult.
loops.empty()) {
250 llvm::errs() <<
"LoopNest shell :\n";
251 tilingResult.
loops.front().dump();
252 llvm::errs() <<
"\n";
257 if (!tilingResult.
loops.empty())
259 tilingResult.
loops.back().getBody()->getTerminator());
261 rewriter, op.getDestinationOperands(rewriter), offsets, sizes,
true);
262 if (tiledImplementation.size() != 1) {
264 op,
"expected tiled implementation to return a single op");
266 tilingResult.
tiledOp = tiledImplementation[0];
269 if (!tilingResult.
loops.empty()) {
270 llvm::errs() <<
"After tiled implementation :\n";
271 tilingResult.
loops.front().dump();
272 llvm::errs() <<
"\n";
277 if (op->getNumResults() == 0) {
285 if (tilingResult.
loops.empty()) {
317 for (
auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
319 if (
failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
322 op.emitOpError(
"unable to get position of result ")
323 << resultNum <<
" of the tiled implementation";
328 Value yieldedValue = b.
create<tensor::InsertSliceOp>(
330 newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
332 yieldedValues.push_back(yieldedValue);
334 return yieldedValues;
337 rewriter, tilingResult.
loops, op.getDestinationOperands(rewriter),
340 rewriter.
eraseOp(loop.value());
341 tilingResult.
loops[loop.index()] = newLoops[loop.index()];
356 tilingPattern(context, std::move(options)) {}
364 tilingPattern(context, std::move(options)) {}
371 auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
374 v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
385 assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
386 "expect same number of iter args");
387 Block *block = &(*innerFor.getRegion().begin());
389 llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
390 Value source = std::get<0>(it);
391 Value target = std::get<1>(it);
403 if (!op->getNumResults()) {
405 op,
"invalid pattern for op with no results");
413 if (
failed(tilingResult)) {
417 tileAndFuseResult.
loops = std::move(tilingResult->loops);
427 auto addCandidateSlices = [](
Operation *fusedOp,
428 std::deque<tensor::ExtractSliceOp> &candidates) {
430 if (
auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
431 candidates.push_back(sliceOp);
434 std::deque<tensor::ExtractSliceOp> candidates;
437 while (!candidates.empty()) {
439 tensor::ExtractSliceOp candidateSliceOp = candidates.front();
440 candidates.pop_front();
446 if (!fusableProducer)
453 fusableProducer.value());
454 if (
failed(fusedProducerValue))
456 rewriter.
replaceOp(candidateSliceOp, fusedProducerValue.value());
461 Operation *fusedProducer = fusedProducerValue->getDefiningOp();
463 addCandidateSlices(fusedProducer, candidates);
499 TilingInterface unfusedProducerOp =
500 cast<TilingInterface>(fusableProducer->getOwner());
501 scf::ForOp outerMostTiledLoop = tileAndFuseResult.
loops.front();
503 unfusedProducerOp.getDestinationOperands(rewriter);
504 for (
OpOperand &uses : unfusedProducerOp->getUses()) {
505 if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
506 unsigned resultNumber = uses.get().cast<
OpResult>().getResultNumber();
507 unsigned operandNumber = uses.getOperandNumber();
508 outerMostTiledLoop->setOperand(
509 operandNumber, unfusedProducerOpDestValues[resultNumber]);
514 tileAndFuseResult.
loops.back(), rewriter);
515 return tileAndFuseResult;
528 if (op->getNumResults() > 0) {
530 op,
"unable to lower to loops operations with return values");
536 for (
auto loopRange : domain) {
543 auto loop = rewriter.
create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
545 loops.push_back(loop);
546 ivs.push_back(loop.getInductionVar());
549 if (
failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
Include the generated interface declarations.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, NewYieldValueFn newYieldValueFn)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< Operation * > tiledAndFusedOps
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
MLIRContext * getContext() const
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
static SmallVector< unsigned > invertPermutationVector(ArrayRef< unsigned > interchange)
Helper method to invert a permutation.
Operation is a basic unit of execution within MLIR.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
FailureOr< SCFTilingResult > returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
This is a value defined by a result of an operation.
operand_range getOperands()
Returns an iterator on the underlying Values.
FailureOr< Value > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Pattern to swap a tensor.extract_slice with its producer when the producer implements the TilingInte...
Block represents an ordered list of Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Pattern to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
SmallVector< scf::ForOp > loops
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
static bool tileDividesIterationDomain(Range loopRange)
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Block * getBlock()
Returns the operation block that contains this operation.
static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, PatternRewriter &rewriter)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class provides support for representing a failure result, or a valid value of type T...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
SmallVector< unsigned > interchangeVector
The interchange vector to reorder the tiled loops.
FailureOr< SCFTileAndFuseResult > returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
Attributes are known-constant values of operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Base type for affine expression.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options, PatternBenefit benefit=1)
Construct a generic pattern applied to all TilingInterface ops.
This class represents an argument of a Block.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static llvm::ManagedStatic< PassManagerOptions > options
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Options to use to control tiling.
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an operand of an operation.
SmallVector< scf::ForOp > loops
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
static SmallVector< unsigned > fillInterchangeVector(ArrayRef< unsigned > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
SCFTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
static bool isPermutation(ArrayRef< unsigned > interchange)
Method to check if an interchange vector is a permutation.
TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, SCFTilingOptions options, PatternBenefit benefit=1)
Construct a generic pattern applied to all TilingInterface ops.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
static SmallVector< scf::ForOp > generateTileLoopNest(OpBuilder &builder, Location loc, ArrayRef< Range > loopRanges, ArrayRef< Value > tileSizeVals, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Generate an empty loop nest that represents the tiled loop nest shell.
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
result_range getResults()
This class helps build Operations.
FailureOr< SmallVector< scf::ForOp > > returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
This class provides an abstraction over the different types of ranges over Values.
IntegerAttr getIndexAttr(int64_t value)
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBBArgs)> NewYieldValueFn
Replace the loop with newIterOperands added as new initialization values.
static Optional< OpResult > getFusableProducer(Value v)
Return the Value that is defined by an operation that implements the TilingInterface.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...