MLIR 22.0.0git
ACCLoopTiling.cpp
Go to the documentation of this file.
1//===- ACCLoopTiling.cpp - Tile ACC Loops ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the OpenACC loop tiling transformation for acc.loop
10// operations that have the tile clause (OpenACC 3.4 spec, section 2.9.8).
11//
12// Overview:
13// ---------
14// The tile clause specifies that the iterations of the associated loops should
15// be divided into tiles (rectangular blocks). This pass transforms a single
16// or nested acc.loop with tile clauses into a structure of "tile loops"
17// (iterating over tiles) containing "element loops" (iterating within tiles).
18//
19// For example, tiling a 2-level nested loop with tile(T1, T2) produces:
20//
21// // Before tiling:
22// acc.loop tile(T1, T2) control(%i, %j) = (lb1, lb2) to (ub1, ub2) step (s1,
23// s2)
24//
25// // After tiling:
26// acc.loop control(%i) = (lb1) to (ub1) step (s1*T1) { // tile loop 1
27// acc.loop control(%j) = (lb2) to (ub2) step (s2*T2) { // tile loop 2
28// acc.loop control(%ii) = (%i) to (min(ub1, %i+s1*T1)) step (s1) { //
29// element 1
30// acc.loop control(%jj) = (%j) to (min(ub2, %j+s2*T2)) step (s2) { //
31// element 2
32// // loop body using %ii, %jj
33// }
34// }
35// }
36// }
37//
38// Gang/worker/vector attributes are distributed as follows:
39// - gang: applied to tile loops
40// - vector: applied to element loops
41// - worker: removed from inner loops
42//
43// Unknown Tile Sizes:
44// -------------------
45// The OpenACC tile(*) syntax indicates an implementation-defined tile size.
46// In the IR, this is represented as -1. The pass resolves these to the
47// default tile size (configurable via pass option).
48//
49// Requirements:
50// -------------
51// 1. The pass uses the OpenACCSupport analysis for remark and NYI (not yet
52// implemented) emission. Custom implementations can be registered via
53// setImplementation() to provide pipeline-specific handling.
54//
55//===----------------------------------------------------------------------===//
56
64#include "mlir/Support/LLVM.h"
66#include "llvm/Support/Debug.h"
67
68namespace mlir {
69namespace acc {
70#define GEN_PASS_DEF_ACCLOOPTILING
71#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
72} // namespace acc
73} // namespace mlir
74
75#define DEBUG_TYPE "acc-loop-tile"
76
77namespace {
78using namespace mlir;
79
80struct ACCLoopTilingImpl : public OpRewritePattern<acc::LoopOp> {
81 ACCLoopTilingImpl(MLIRContext *context, int32_t defaultTileSize,
82 acc::OpenACCSupport &accSupport)
84 defaultTileSize(defaultTileSize), accSupport(accSupport) {}
85
86 // Check that tile size types are not narrower than IV types.
87 // We only check when both types are IntegerType. For IndexType, the width
88 // is target-dependent and the casting utility will handle it correctly.
89 LogicalResult checkTileSizeTypes(acc::LoopOp loop,
90 ArrayRef<Value> tileSizes) const {
91 auto ivTypes = loop.getBody().getArgumentTypes();
92 for (size_t i = 0; i < tileSizes.size() && i < ivTypes.size(); ++i) {
93 Type tileType = tileSizes[i].getType();
94 Type ivType = ivTypes[i];
95
96 // Skip unknown tile sizes (will be created with correct type)
97 auto constVal = getConstantIntValue(tileSizes[i]);
98 if (constVal && *constVal < 0)
99 continue;
100
101 // Only compare when both are integer types (not index)
102 auto tileIntType = dyn_cast<IntegerType>(tileType);
103 auto ivIntType = dyn_cast<IntegerType>(ivType);
104 if (tileIntType && ivIntType) {
105 if (tileIntType.getWidth() > ivIntType.getWidth()) {
106 accSupport.emitNYI(loop.getLoc(),
107 "tile size type (i" +
108 std::to_string(tileIntType.getWidth()) +
109 ") is wider than loop IV type (i" +
110 std::to_string(ivIntType.getWidth()) + ")");
111 return failure();
112 }
113 }
114 }
115 return success();
116 }
117
118 void emitTilingRemarks(acc::LoopOp loop, ArrayRef<Value> tileSizes) const {
119 // Emit remarks for loop tiling
120 size_t tileLevel = tileSizes.size();
121 std::string msg =
122 "Tiling " + std::to_string(tileLevel) + "-level loop nest with tile(";
123 for (size_t i = 0; i < tileSizes.size(); ++i) {
124 std::optional<int64_t> val = getConstantIntValue(tileSizes[i]);
125 if (*val == -1)
126 msg += "*";
127 else
128 msg += std::to_string(*val);
129 if (i < tileSizes.size() - 1)
130 msg += ",";
131 }
132 msg += ")";
133 accSupport.emitRemark(loop, llvm::Twine(msg), DEBUG_TYPE);
134
135 // Emit remarks for unknown tile sizes that will be resolved to default
136 // TODO: Need to base the default tile size on some heuristics.
137 for (Value tileSize : tileSizes) {
138 std::optional<int64_t> val = getConstantIntValue(tileSize);
139 if (val && *val < 0) {
140 std::string unknownMsg = "Picking default tile size " +
141 std::to_string(defaultTileSize) +
142 " for unknown tile size '*'";
143 accSupport.emitRemark(loop, llvm::Twine(unknownMsg), DEBUG_TYPE);
144 }
145 }
146 }
147
148 LogicalResult matchAndRewrite(acc::LoopOp origLoop,
149 PatternRewriter &rewriter) const override {
150
151 if (origLoop.getTileValues().empty())
152 return success();
153
154 SmallVector<Value> tileSizes(origLoop.getTileValues().begin(),
155 origLoop.getTileValues().end());
156 unsigned tileCount = tileSizes.size();
157 unsigned collapseCount = origLoop.getCollapseValue().value_or(1);
158
159 // Sanity check tile size types
160 if (failed(checkTileSizeTypes(origLoop, tileSizes)))
161 return failure();
162
163 // Emit remarks for loop tiling. This is emitted before the original loop
164 // is modified. However, it assumes that tiling will not fail.
165 emitTilingRemarks(origLoop, tileSizes);
166
167 LLVM_DEBUG(llvm::dbgs() << "\nBefore tiling:\n" << *origLoop << "\n");
168
169 // Clear tile operands from origLoop
170 rewriter.startOpModification(origLoop);
171 origLoop.getTileOperandsMutable().clear();
172 origLoop.removeTileOperandsSegmentsAttr();
173 origLoop.removeTileOperandsDeviceTypeAttr();
174 rewriter.finalizeOpModification(origLoop);
175
176 SmallVector<acc::LoopOp> loopsToTile;
177 if (collapseCount < tileCount) {
178 // Uncollapse tile loops before tiling if necessary
179 loopsToTile =
180 acc::uncollapseLoops(origLoop, tileCount, collapseCount, rewriter);
181 rewriter.replaceOp(origLoop, loopsToTile[0]);
182 LLVM_DEBUG(llvm::dbgs() << "\nAfter uncollapsing:\n"
183 << *loopsToTile[0] << "\n");
184 } else {
185 loopsToTile.push_back(origLoop);
186 }
187
188 // loopsToTile is a vector of perfectly nested loops. The outermost loop
189 // may have multiple IVs but inner loops can only have one IV.
190 // The utility handles unknown tile sizes (*) by using `defaultTileSize`.
191 acc::tileACCLoops(loopsToTile, tileSizes, defaultTileSize, rewriter);
192
193 LLVM_DEBUG(llvm::dbgs() << "\nAfter tiling:\n " << *loopsToTile[0] << "\n");
194 return success();
195 }
196
197private:
198 int32_t defaultTileSize;
199 acc::OpenACCSupport &accSupport;
200};
201
202class ACCLoopTiling : public acc::impl::ACCLoopTilingBase<ACCLoopTiling> {
203public:
204 using ACCLoopTilingBase<ACCLoopTiling>::ACCLoopTilingBase;
205
206 void runOnOperation() override {
207 func::FuncOp funcOp = getOperation();
208 MLIRContext *context = funcOp.getContext();
209 acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
210
212 patterns.insert<ACCLoopTilingImpl>(context, defaultTileSize, accSupport);
214 grc.setUseTopDownTraversal(true);
215 grc.setMaxIterations(1);
216 (void)applyPatternsGreedily(funcOp, std::move(patterns), grc);
217 }
218};
219
220} // namespace
return success()
#define DEBUG_TYPE
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setMaxIterations(int64_t iterations)
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
remark::detail::InFlightRemark emitRemark(Operation *op, const Twine &message, llvm::StringRef category="openacc")
Emit an OpenACC remark.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::acc::LoopOp tileACCLoops(llvm::SmallVector< mlir::acc::LoopOp > &tileLoops, const llvm::SmallVector< mlir::Value > &tileSizes, int32_t defaultTileSize, mlir::RewriterBase &rewriter)
Tile ACC loops according to the given tile sizes.
llvm::SmallVector< mlir::acc::LoopOp > uncollapseLoops(mlir::acc::LoopOp origLoop, unsigned tileCount, unsigned collapseCount, mlir::RewriterBase &rewriter)
Uncollapse tile loops with multiple IVs and collapseCount < tileCount.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...