26 #include "llvm/Support/Debug.h"
29 #define DEBUG_TYPE "tile-using-interface"
43 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
55 size_t iterationDomainSize) {
57 if (filledVector.size() < iterationDomainSize) {
58 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
59 filledVector.append(range.begin(), range.end());
61 if (filledVector.size() > iterationDomainSize)
62 filledVector.resize(iterationDomainSize);
81 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
90 if (ts && ts.value() == 1)
120 assert(!loopRanges.empty() &&
"expected at least one loop range");
121 assert(loopRanges.size() == tileSizeVals.size() &&
122 "expected as many tile sizes as loop ranges");
125 offsets.resize(loopRanges.size());
126 sizes.resize(loopRanges.size());
133 Value tileSize = tileSizeVals[loopRange.index()];
137 offsets[loopRange.index()] = offset;
138 sizes[loopRange.index()] = size;
142 auto loop = builder.
create<scf::ForOp>(
147 bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
148 builder.
create<scf::YieldOp>(loc);
150 offsets[loopRange.index()] = loop.getInductionVar();
151 loops.push_back(loop);
191 tileOffsetsList[yieldedValue.index()];
196 loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
197 tileOffsets, tileSizes, tileStrides);
198 inserts.push_back(insert);
207 rewriter.
eraseOp(loop.value());
208 loops[loop.index()] = newLoops[loop.index()];
210 return llvm::to_vector(llvm::map_range(
211 loops.front().getResults().take_back(yieldedValues.size()),
243 for (
const auto &destValue :
llvm::enumerate(tiledOpDestinationValues)) {
244 auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
247 sliceOp.setOperand(0, bbArgsList[destValue.index()]);
262 tileOffsetsList, tileSizesList, loops);
263 for (
auto tiledOp : tilingResult.
tiledOps) {
264 if (
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
265 auto innerMostLoop = loops.back();
268 innerMostLoop.getRegionIterArgs());
282 if (!
options.tileSizeComputationFunction) {
284 op,
"missing tile size computation function");
289 size_t numLoops = iterationDomain.size();
292 op,
"unable to tile op with no iteration domain");
300 options.tileSizeComputationFunction(rewriter, op);
301 if (tileSizeVector.size() < iterationDomain.size()) {
303 tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
312 if (!
options.interchangeVector.empty()) {
314 iterationDomain.size());
316 if (!interchangeVector.empty()) {
319 op,
"invalid intechange vector, not a permutation of the entire "
331 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
333 if (!interchangeVector.empty()) {
341 if (!tilingResult.
loops.empty()) {
342 llvm::dbgs() <<
"LoopNest shell :\n";
343 tilingResult.loops.front().dump();
344 llvm::dbgs() <<
"\n";
349 if (!tilingResult.
loops.empty())
351 tilingResult.
loops.back().getBody()->getTerminator());
353 op.getTiledImplementation(rewriter, offsets, sizes);
354 tilingResult.
tiledOps.append(tiledImplementation->tiledOps);
355 if (op->getNumResults() == 0) {
362 if (tilingResult.
loops.empty()) {
363 tilingResult.
replacements = tiledImplementation->tiledValues;
370 int64_t numResults = op->getNumResults();
372 resultSizesList(numResults);
374 if (
failed(op.getResultTilePosition(rewriter, result.index(), offsets,
376 resultOffsetsList[result.index()],
377 resultSizesList[result.index()]))) {
379 op,
"failed to get slice of result produced");
385 destinationTensors)))
389 rewriter, destinationTensors, tiledImplementation.value(),
390 resultOffsetsList, resultSizesList, tilingResult.
loops);
393 if (!tilingResult.
loops.empty()) {
394 llvm::dbgs() <<
"After tiled implementation :\n";
395 tilingResult.loops.front().dump();
396 llvm::dbgs() <<
"\n";
404 PartialReductionOpInterface op,
409 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
413 if (tileSizeVector.size() < iterationDomain.size()) {
415 tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
417 if (op->getNumResults() != 1)
419 op,
"don't support ops with multiple results for now");
421 tilingInterfaceOp.getLoopIteratorTypes();
422 int64_t numReductionDims = llvm::count(
423 tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
424 if (numReductionDims != 1)
426 op,
"only support ops with one reduction dimension.");
428 for (
auto [idx, iteratorType] :
430 if (iteratorType == utils::IteratorType::reduction) {
435 if (
static_cast<size_t>(reductionDim) >= tileSize.size())
440 op.generateInitialTensorForPartialReduction(b, loc, tileSize,
442 if (
failed(identityTensor))
444 "cannot create a tensor of identity value.");
448 b, loc, iterationDomain, tileSizeVector, offsets, sizes);
452 Operation *parallelOp = op.tileToPartialReduction(
453 b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);
456 for (
size_t i = 0; i < offsets.size(); i++)
457 resultSizesList.push_back(
461 b, (*identityTensor)->getResults(), parallelOp->
getResults(), outOffsets,
462 resultSizesList, loops);
464 auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
465 auto innerMostLoop = loops.back();
467 assert(destinationTensors.size() ==
468 innerMostLoop.getRegionIterArgs().size() &&
469 "unexpected number of outputs");
471 innerMostLoop.getRegionIterArgs());
475 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
480 results.
loops = std::move(loops);
494 static std::tuple<OpResult, std::optional<OpOperand *>>
497 std::optional<OpOperand *> destinationIterArg;
498 auto loopIt = loops.rbegin();
500 scf::ForOp loop = *loopIt;
501 if (iterArg.getOwner()->getParentOp() != loop)
503 source = &loop.getOpOperandForRegionIterArg(iterArg);
506 if (loopIt == loops.rend())
507 destinationIterArg = source;
513 std::optional<scf::SCFFuseProducerOfSliceResult>
515 tensor::ExtractSliceOp candidateSliceOp,
519 auto [fusableProducer, destinationIterArg] =
522 if (!fusableProducer)
531 if (
failed(tileAndFuseResult))
534 tileAndFuseResult->tiledValues[0]);
585 scf::ForOp outerMostLoop = loops.front();
586 std::optional<unsigned> iterArgNumber;
587 if (destinationIterArg) {
589 outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
592 int64_t resultNumber = fusableProducer.getResultNumber();
594 dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
595 outerMostLoop.setIterArg(iterArgNumber.value(),
596 dstOp.getTiedOpOperand(fusableProducer)->get());
598 for (
auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
599 auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
602 scf::ForOp innerMostLoop = loops.back();
604 rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
605 innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
609 tileAndFuseResult->tiledValues[0],
610 tileAndFuseResult->tiledOps};
618 auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
622 rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
628 resultOffsets, resultSizes, loops);
630 for (
auto tileAndFusedOp : tileAndFusedOps) {
631 auto dstStyleProducer =
632 dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
633 if (!dstStyleProducer)
636 dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
639 rewriter, dstValue, loops.back().getRegionIterArgs().back());
650 if (!consumer->getNumResults()) {
652 consumer,
"invalid pattern for op with no results");
657 llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
663 for (
auto *tiledOp : tilingResult->tiledOps)
665 tileAndFuseResult.
loops = std::move(tilingResult->loops);
667 llvm::zip(consumer->getResults(), tilingResult->replacements))) {
668 tileAndFuseResult.
replacements[std::get<0>(result.value())] =
669 std::get<1>(result.value());
670 yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
671 result.index())] = result.index();
676 if (tileAndFuseResult.
loops.empty())
677 return tileAndFuseResult;
686 auto addCandidateSlices = [](
Operation *fusedOp,
687 std::deque<tensor::ExtractSliceOp> &candidates) {
689 if (
auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
690 candidates.push_back(sliceOp);
693 std::deque<tensor::ExtractSliceOp> candidates;
696 while (!candidates.empty()) {
698 tensor::ExtractSliceOp candidateSliceOp = candidates.front();
699 candidates.pop_front();
704 std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
706 tileAndFuseResult.
loops);
711 fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
713 addCandidateSlices(tiledAndFusedOp, candidates);
716 return tileAndFuseResult;
725 TilingInterface op) {
727 if (op->getNumResults() > 0) {
729 op,
"unable to lower to loops operations with return values");
736 for (
auto loopRange : domain) {
743 auto loop = rewriter.
create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
745 loops.push_back(loop);
746 ivs.push_back(loop.getInductionVar());
749 if (
failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
static llvm::ManagedStatic< PassManagerOptions > options
static void updateDestinationOperandsForTiledOp(OpBuilder &builder, ValueRange tiledOpDestinationValues, ValueRange bbArgsList)
If the tiled operation is destination passing style, update the slice of the destination used (which ...
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, Value tileSize)
Returns the bounded tile size given the current iv, loopRange and tileSize, i.e., min(tileSize,...
static SmallVector< Value > yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, ValueRange yieldedValues, ArrayRef< SmallVector< OpFoldResult >> tileOffsetsList, ArrayRef< SmallVector< OpFoldResult >> tileSizesList, MutableArrayRef< scf::ForOp > loops)
For a value to be yielded (yieldedValue) from within a loop nest loops, construct the destructive upd...
static SmallVector< scf::ForOp > generateTileLoopNest(OpBuilder &builder, Location loc, ArrayRef< Range > loopRanges, ArrayRef< Value > tileSizeVals, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Generate an empty loop nest that represents the tiled loop nest shell.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< scf::ForOp > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Specialization of arith.constant op that returns an integer of index type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< scf::ForOp > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
void yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< scf::ForOp > loops)
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTilingResult > tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Pattern to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBBArgs)> NewYieldValueFn
Replace the loop with newIterOperands added as new initialization values.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Container for result values of tiling.
SmallVector< Value > tiledValues
SmallVector< Operation * > tiledOps
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Transformation information returned after reduction tiling.
Operation * parallelTiledOp
The partial reduction tiled op generated.
Operation * mergeOp
The final reduction operation merging all the partial reductions.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
Operation * initialOp
Initial op.
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
llvm::DenseMap< Value, Value > replacements
The replacement values to use for the tiled and fused operations.
llvm::SetVector< Operation * > tiledAndFusedOps
List of tiled and fused operations generated.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
SCFTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
SmallVector< Value > replacements
Values to use as replacements for the untiled op.