MLIR 23.0.0git
ACCLoopTiling.cpp
Go to the documentation of this file.
1//===- ACCLoopTiling.cpp - Tile ACC Loops ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the OpenACC loop tiling transformation for acc.loop
10// operations that have the tile clause (OpenACC 3.4 spec, section 2.9.8).
11//
12// Overview:
13// ---------
14// The tile clause specifies that the iterations of the associated loops should
15// be divided into tiles (rectangular blocks). This pass transforms a single
16// or nested acc.loop with tile clauses into a structure of "tile loops"
17// (iterating over tiles) containing "element loops" (iterating within tiles).
18//
19// For example, tiling a 2-level nested loop with tile(T1, T2) produces:
20//
21// // Before tiling:
22// acc.loop tile(T1, T2) control(%i, %j) = (lb1, lb2) to (ub1, ub2) step (s1,
23// s2)
24//
25// // After tiling:
26// acc.loop control(%i) = (lb1) to (ub1) step (s1*T1) { // tile loop 1
27// acc.loop control(%j) = (lb2) to (ub2) step (s2*T2) { // tile loop 2
28// acc.loop control(%ii) = (%i) to (min(ub1, %i+s1*T1)) step (s1) { //
29// element 1
30// acc.loop control(%jj) = (%j) to (min(ub2, %j+s2*T2)) step (s2) { //
31// element 2
32// // loop body using %ii, %jj
33// }
34// }
35// }
36// }
37//
38// Gang/worker/vector attributes are distributed as follows:
39// - gang: applied to tile loops
40// - vector: applied to element loops
41// - worker: removed from inner loops
42//
43// Unknown Tile Sizes:
44// -------------------
45// The OpenACC tile(*) syntax indicates an implementation-defined tile size.
46// In the IR, this is represented as -1. The pass resolves these to the
47// default tile size (configurable via pass option).
48//
49// Requirements:
50// -------------
51// 1. The pass uses the OpenACCSupport analysis for remark and NYI (not yet
52// implemented) emission. Custom implementations can be registered via
53// setImplementation() to provide pipeline-specific handling.
54//
55//===----------------------------------------------------------------------===//
56
64#include "mlir/Support/LLVM.h"
66#include "llvm/ADT/StringExtras.h"
67#include "llvm/Support/Debug.h"
68
69namespace mlir {
70namespace acc {
71#define GEN_PASS_DEF_ACCLOOPTILING
72#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
73} // namespace acc
74} // namespace mlir
75
76#define DEBUG_TYPE "acc-loop-tile"
77
78namespace {
79using namespace mlir;
80
81struct ACCLoopTilingImpl : public OpRewritePattern<acc::LoopOp> {
82 ACCLoopTilingImpl(MLIRContext *context, int32_t defaultTileSize,
83 acc::OpenACCSupport &accSupport)
85 defaultTileSize(defaultTileSize), accSupport(accSupport) {}
86
87 // Check that tile size types are not narrower than IV types.
88 // We only check when both types are IntegerType. For IndexType, the width
89 // is target-dependent and the casting utility will handle it correctly.
90 LogicalResult checkTileSizeTypes(acc::LoopOp loop,
91 ArrayRef<Value> tileSizes) const {
92 auto ivTypes = loop.getBody().getArgumentTypes();
93 for (size_t i = 0; i < tileSizes.size() && i < ivTypes.size(); ++i) {
94 Type tileType = tileSizes[i].getType();
95 Type ivType = ivTypes[i];
96
97 // Skip unknown tile sizes (will be created with correct type)
98 auto constVal = getConstantIntValue(tileSizes[i]);
99 if (constVal && *constVal < 0)
100 continue;
101
102 // Only compare when both are integer types (not index)
103 auto tileIntType = dyn_cast<IntegerType>(tileType);
104 auto ivIntType = dyn_cast<IntegerType>(ivType);
105 if (tileIntType && ivIntType) {
106 if (tileIntType.getWidth() > ivIntType.getWidth()) {
107 accSupport.emitNYI(loop.getLoc(),
108 "tile size type (i" +
109 std::to_string(tileIntType.getWidth()) +
110 ") is wider than loop IV type (i" +
111 std::to_string(ivIntType.getWidth()) + ")");
112 return failure();
113 }
114 }
115 }
116 return success();
117 }
118
119 void emitTilingRemarks(acc::LoopOp loop, ArrayRef<Value> tileSizes) const {
120 // Emit remarks for loop tiling
121 accSupport.emitRemark(
122 loop,
123 [&]() {
124 auto getTileSizeStr = [&](Value v) -> std::string {
125 std::string name = accSupport.getVariableName(v);
126 // Use "*" for unknown tile sizes (represented as -1 or empty)
127 if (name.empty() || name == "-1")
128 return "*";
129 return name;
130 };
132 for (Value v : tileSizes)
133 tileStrs.push_back(getTileSizeStr(v));
134 return "Tiling " + std::to_string(tileSizes.size()) +
135 "-level loop nest with tile(" + llvm::join(tileStrs, ",") +
136 ")";
137 },
138 DEBUG_TYPE);
139
140 // Emit remarks for unknown tile sizes that will be resolved to default
141 // TODO: Need to base the default tile size on some heuristics.
142 for (Value tileSize : tileSizes) {
143 std::optional<int64_t> val = getConstantIntValue(tileSize);
144 if (val && *val < 0) {
145 accSupport.emitRemark(
146 loop,
147 [&]() {
148 return "Picking default tile size " +
149 std::to_string(defaultTileSize) +
150 " for unknown tile size '*'";
151 },
152 DEBUG_TYPE);
153 }
154 }
155 }
156
157 LogicalResult matchAndRewrite(acc::LoopOp origLoop,
158 PatternRewriter &rewriter) const override {
159
160 if (origLoop.getTileValues().empty())
161 return success();
162
163 SmallVector<Value> tileSizes(origLoop.getTileValues().begin(),
164 origLoop.getTileValues().end());
165 unsigned tileCount = tileSizes.size();
166 unsigned collapseCount = origLoop.getCollapseValue().value_or(1);
167
168 // Sanity check tile size types
169 if (failed(checkTileSizeTypes(origLoop, tileSizes)))
170 return failure();
171
172 // Emit remarks for loop tiling. This is emitted before the original loop
173 // is modified. However, it assumes that tiling will not fail.
174 emitTilingRemarks(origLoop, tileSizes);
175
176 LLVM_DEBUG(llvm::dbgs() << "\nBefore tiling:\n" << *origLoop << "\n");
177
178 // Clear tile operands from origLoop
179 rewriter.startOpModification(origLoop);
180 origLoop.getTileOperandsMutable().clear();
181 origLoop.removeTileOperandsSegmentsAttr();
182 origLoop.removeTileOperandsDeviceTypeAttr();
183 rewriter.finalizeOpModification(origLoop);
184
185 SmallVector<acc::LoopOp> loopsToTile;
186 if (collapseCount < tileCount) {
187 // Uncollapse tile loops before tiling if necessary
188 loopsToTile =
189 acc::uncollapseLoops(origLoop, tileCount, collapseCount, rewriter);
190 rewriter.replaceOp(origLoop, loopsToTile[0]);
191 LLVM_DEBUG(llvm::dbgs() << "\nAfter uncollapsing:\n"
192 << *loopsToTile[0] << "\n");
193 } else {
194 loopsToTile.push_back(origLoop);
195 }
196
197 // loopsToTile is a vector of perfectly nested loops. The outermost loop
198 // may have multiple IVs but inner loops can only have one IV.
199 // The utility handles unknown tile sizes (*) by using `defaultTileSize`.
200 acc::tileACCLoops(loopsToTile, tileSizes, defaultTileSize, rewriter);
201
202 LLVM_DEBUG(llvm::dbgs() << "\nAfter tiling:\n " << *loopsToTile[0] << "\n");
203 return success();
204 }
205
206private:
207 int32_t defaultTileSize;
208 acc::OpenACCSupport &accSupport;
209};
210
211class ACCLoopTiling : public acc::impl::ACCLoopTilingBase<ACCLoopTiling> {
212public:
213 using ACCLoopTilingBase<ACCLoopTiling>::ACCLoopTilingBase;
214
215 void runOnOperation() override {
216 func::FuncOp funcOp = getOperation();
217 MLIRContext *context = funcOp.getContext();
218 acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
219
221 patterns.insert<ACCLoopTilingImpl>(context, defaultTileSize, accSupport);
223 grc.setUseTopDownTraversal(true);
224 grc.setMaxIterations(1);
225 (void)applyPatternsGreedily(funcOp, std::move(patterns), grc);
226 }
227};
228
229} // namespace
return success()
#define DEBUG_TYPE
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setMaxIterations(int64_t iterations)
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
remark::detail::InFlightRemark emitRemark(Operation *op, std::function< std::string()> messageFn, llvm::StringRef category="openacc")
Emit an OpenACC remark with lazy message generation.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
std::string getVariableName(Value v)
Get the variable name for a given value.
mlir::acc::LoopOp tileACCLoops(llvm::SmallVector< mlir::acc::LoopOp > &tileLoops, const llvm::SmallVector< mlir::Value > &tileSizes, int32_t defaultTileSize, mlir::RewriterBase &rewriter)
Tile ACC loops according to the given tile sizes.
llvm::SmallVector< mlir::acc::LoopOp > uncollapseLoops(mlir::acc::LoopOp origLoop, unsigned tileCount, unsigned collapseCount, mlir::RewriterBase &rewriter)
Uncollapse tile loops with multiple IVs and collapseCount < tileCount.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.