66#include "llvm/ADT/StringExtras.h"
67#include "llvm/Support/Debug.h"
71#define GEN_PASS_DEF_ACCLOOPTILING
72#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
76#define DEBUG_TYPE "acc-loop-tile"
82 ACCLoopTilingImpl(
MLIRContext *context, int32_t defaultTileSize,
85 defaultTileSize(defaultTileSize), accSupport(accSupport) {}
90 LogicalResult checkTileSizeTypes(acc::LoopOp loop,
92 auto ivTypes = loop.getBody().getArgumentTypes();
93 for (
size_t i = 0; i < tileSizes.size() && i < ivTypes.size(); ++i) {
94 Type tileType = tileSizes[i].getType();
95 Type ivType = ivTypes[i];
99 if (constVal && *constVal < 0)
103 auto tileIntType = dyn_cast<IntegerType>(tileType);
104 auto ivIntType = dyn_cast<IntegerType>(ivType);
105 if (tileIntType && ivIntType) {
106 if (tileIntType.getWidth() > ivIntType.getWidth()) {
107 accSupport.
emitNYI(loop.getLoc(),
108 "tile size type (i" +
109 std::to_string(tileIntType.getWidth()) +
110 ") is wider than loop IV type (i" +
111 std::to_string(ivIntType.getWidth()) +
")");
119 void emitTilingRemarks(acc::LoopOp loop,
ArrayRef<Value> tileSizes)
const {
124 auto getTileSizeStr = [&](
Value v) -> std::string {
127 if (name.empty() || name ==
"-1")
132 for (
Value v : tileSizes)
133 tileStrs.push_back(getTileSizeStr(v));
134 return "Tiling " + std::to_string(tileSizes.size()) +
135 "-level loop nest with tile(" + llvm::join(tileStrs,
",") +
142 for (
Value tileSize : tileSizes) {
144 if (val && *val < 0) {
148 return "Picking default tile size " +
149 std::to_string(defaultTileSize) +
150 " for unknown tile size '*'";
157 LogicalResult matchAndRewrite(acc::LoopOp origLoop,
160 if (origLoop.getTileValues().empty())
164 origLoop.getTileValues().end());
165 unsigned tileCount = tileSizes.size();
166 unsigned collapseCount = origLoop.getCollapseValue().value_or(1);
169 if (failed(checkTileSizeTypes(origLoop, tileSizes)))
174 emitTilingRemarks(origLoop, tileSizes);
176 LLVM_DEBUG(llvm::dbgs() <<
"\nBefore tiling:\n" << *origLoop <<
"\n");
180 origLoop.getTileOperandsMutable().clear();
181 origLoop.removeTileOperandsSegmentsAttr();
182 origLoop.removeTileOperandsDeviceTypeAttr();
186 if (collapseCount < tileCount) {
190 rewriter.
replaceOp(origLoop, loopsToTile[0]);
191 LLVM_DEBUG(llvm::dbgs() <<
"\nAfter uncollapsing:\n"
192 << *loopsToTile[0] <<
"\n");
194 loopsToTile.push_back(origLoop);
202 LLVM_DEBUG(llvm::dbgs() <<
"\nAfter tiling:\n " << *loopsToTile[0] <<
"\n");
207 int32_t defaultTileSize;
211class ACCLoopTiling :
public acc::impl::ACCLoopTilingBase<ACCLoopTiling> {
213 using ACCLoopTilingBase<ACCLoopTiling>::ACCLoopTilingBase;
215 void runOnOperation()
override {
216 func::FuncOp funcOp = getOperation();
221 patterns.insert<ACCLoopTilingImpl>(context, defaultTileSize, accSupport);
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setMaxIterations(int64_t iterations)
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
remark::detail::InFlightRemark emitRemark(Operation *op, std::function< std::string()> messageFn, llvm::StringRef category="openacc")
Emit an OpenACC remark with lazy message generation.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
std::string getVariableName(Value v)
Get the variable name for a given value.
mlir::acc::LoopOp tileACCLoops(llvm::SmallVector< mlir::acc::LoopOp > &tileLoops, const llvm::SmallVector< mlir::Value > &tileSizes, int32_t defaultTileSize, mlir::RewriterBase &rewriter)
Tile ACC loops according to the given tile sizes.
llvm::SmallVector< mlir::acc::LoopOp > uncollapseLoops(mlir::acc::LoopOp origLoop, unsigned tileCount, unsigned collapseCount, mlir::RewriterBase &rewriter)
Uncollapse tile loops with multiple IVs and collapseCount < tileCount.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...