//===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_GPUASYNCREGIONPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
class GpuAsyncRegionPass
    : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
  struct ThreadTokenCallback;
  struct DeferWaitCallback;
  struct SingleTokenUseCallback;
  void runOnOperation() override;
};
} // namespace

static bool isTerminator(Operation *op) {
  return op->mightHaveTrait<OpTrait::IsTerminator>();
}

static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); }

// Region walk callback which makes GPU ops implementing the AsyncOpInterface
// execute asynchronously.
struct GpuAsyncRegionPass::ThreadTokenCallback {
  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

  WalkResult operator()(Block *block) {
    for (Operation &op : make_early_inc_range(*block)) {
      if (failed(visit(&op)))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  }

private:
  // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
  // create a current token (unless it already exists), and 'thread' that token
  // through the `op` so that it executes asynchronously.
  //
  // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
  // host-synchronize execution. A `!gpu.async.token` will therefore only be
  // used inside of its block and GPU execution will always synchronize with
  // the host at block boundaries.
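  // A rough sketch (illustrative ops and names, not from this file): a block
  //   gpu.launch_func @kernels::@fn ...   // implements AsyncOpInterface
  //   memref.store ...                    // op with side effects
  // is rewritten to approximately
  //   %t0 = gpu.wait async
  //   %t1 = gpu.launch_func async [%t0] @kernels::@fn ...
  //   gpu.wait [%t1]                      // host-synchronize before the store
  //   memref.store ...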
  LogicalResult visit(Operation *op) {
    if (isa<gpu::LaunchOp>(op))
      return op->emitOpError("replace with gpu.launch_func first");
    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
      if (currentToken)
        waitOp.addAsyncDependency(currentToken);
      currentToken = waitOp.getAsyncToken();
      return success();
    }
    builder.setInsertionPoint(op);
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
      return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
    if (!currentToken)
      return success();
    // Insert host synchronization before terminator or op with side effects.
    if (isTerminator(op) || hasSideEffects(op))
      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
    return success();
  }

  // Replaces asyncOp with a clone that returns a token.
  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
    auto *op = asyncOp.getOperation();
    auto tokenType = builder.getType<gpu::AsyncTokenType>();

    // If there is no current token, insert a `gpu.wait async` without
    // dependencies to create one.
    if (!currentToken)
      currentToken = createWaitOp(op->getLoc(), tokenType, {});
    asyncOp.addAsyncDependency(currentToken);

    // Return early if op returns a token already.
    currentToken = asyncOp.getAsyncToken();
    if (currentToken)
      return success();

    // Clone the op to return a token in addition to the other results.
    SmallVector<Type, 1> resultTypes;
    resultTypes.reserve(1 + op->getNumResults());
    copy(op->getResultTypes(), std::back_inserter(resultTypes));
    resultTypes.push_back(tokenType);
    auto *newOp = Operation::create(
        op->getLoc(), op->getName(), resultTypes, op->getOperands(),
        op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
        op->getSuccessors(), op->getNumRegions());

    // Clone regions into new op.
    IRMapping mapping;
    for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
      std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);

    // Replace the op with the async clone.
    auto results = newOp->getResults();
    currentToken = results.back();
    builder.insert(newOp);
    op->replaceAllUsesWith(results.drop_back());
    op->erase();

    return success();
  }

  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
    return gpu::WaitOp::create(builder, loc, resultType, operands)
        .getAsyncToken();
  }

  OpBuilder builder;

  // The token that represents the current asynchronous dependency. Its valid
  // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
  // In between, each gpu::AsyncOpInterface depends on the current token and
  // produces the new one.
  Value currentToken = {};
};

/// Erases `executeOp` and returns a clone with additional `results`.
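/// A rough sketch (illustrative types and names, not from this file): given
///   %t, %v = async.execute -> !async.value<f32> { ... }
/// and one extra result of type !gpu.async.token, the clone is roughly
///   %t, %v, %w = async.execute
///       -> (!async.value<f32>, !async.value<!gpu.async.token>) { ... }
/// with all uses of the original results redirected to the clone.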
static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
                                          ValueRange results) {
  // Add values to async.yield op.
  Operation *yieldOp = executeOp.getBody()->getTerminator();
  yieldOp->insertOperands(yieldOp->getNumOperands(), results);

  // Construct new result type list with additional types.
  SmallVector<Type, 2> resultTypes;
  resultTypes.reserve(executeOp.getNumResults() + results.size());
  transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
            [](Type type) {
              // Extract value type from !async.value.
              if (auto valueType = dyn_cast<async::ValueType>(type))
                return valueType.getValueType();
              assert(isa<async::TokenType>(type) && "expected token type");
              return type;
            });
  transform(results, std::back_inserter(resultTypes),
            [](Value value) { return value.getType(); });

  // Clone executeOp with the extra results.
  OpBuilder builder(executeOp);
  auto newOp = async::ExecuteOp::create(
      builder, executeOp.getLoc(),
      TypeRange{resultTypes}.drop_front() /*drop token*/,
      executeOp.getDependencies(), executeOp.getBodyOperands());
  IRMapping mapper;
  newOp.getRegion().getBlocks().clear();
  executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

  // Replace executeOp with cloned one.
  executeOp.getOperation()->replaceAllUsesWith(
      newOp.getResults().drop_back(results.size()));
  executeOp.erase();

  return newOp;
}

// Callback for `async.execute` ops which tries to push the contained
// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
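// A rough sketch (illustrative ops, not from this file): a region such as
//   %a = async.execute {
//     %t = gpu.launch_func async ...
//     gpu.wait [%t]        // synchronous wait at the end of the region
//     async.yield
//   }
// becomes, roughly,
//   %a, %g = async.execute -> !async.value<!gpu.async.token> {
//     %t = gpu.launch_func async ...
//     async.yield %t : !gpu.async.token
//   }
// and each user of %a is extended to await or depend on the yielded token.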
struct GpuAsyncRegionPass::DeferWaitCallback {
  // If the `executeOp`'s token is used only in `async.execute` or
  // `async.await` ops, add the region's last `gpu.wait` op to the worklist if
  // it is synchronous and is the last op with side effects.
  void operator()(async::ExecuteOp executeOp) {
    if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
      return;
    // async.execute's region is currently restricted to one block.
    for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
      if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
        if (!waitOp.getAsyncToken())
          worklist.push_back(waitOp);
        return;
      }
      if (hasSideEffects(&op))
        return;
    }
  }

  // The destructor performs the actual rewrite work.
  ~DeferWaitCallback() {
    for (size_t i = 0; i < worklist.size(); ++i) {
      auto waitOp = worklist[i];
      auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();

      // Erase `gpu.wait` and return async dependencies from execute op instead.
      SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
      waitOp.erase();
      executeOp = addExecuteResults(executeOp, dependencies);

      // Add the async dependency to each user of the `async.execute` token.
      auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
      SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
                                        executeOp.getToken().user_end());
      for (Operation *user : users)
        addAsyncDependencyAfter(asyncTokens, user);
    }
  }

private:
  // Returns whether all token users are either 'async.execute' or
  // 'async.await' ops. This is used as a requirement for pushing 'gpu.wait'
  // ops from an 'async.execute' body to its users. Specifically, we do not
  // allow terminator users, because it could mean that the `async.execute` is
  // inside control flow code.
  static bool areAllUsersExecuteOrAwait(Value token) {
    return !token.use_empty() &&
           llvm::all_of(token.getUsers(),
                        llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
  }

  // Adds the `asyncTokens` as dependencies, as needed, after `op`.
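  // A rough sketch (illustrative, not from this file): for an async.await
  // user,
  //   %t = async.await %g : !async.value<!gpu.async.token>
  // is inserted right after it and %t is threaded into the next async op or a
  // new gpu.wait; for an async.execute user, the tokens are appended to the
  // execute operands and threaded in as new !gpu.async.token block arguments.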
  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
    OpBuilder builder(op->getContext());
    auto loc = op->getLoc();

    Block::iterator it;
    SmallVector<Value, 1> tokens;
    tokens.reserve(asyncTokens.size());
    TypeSwitch<Operation *>(op)
        .Case<async::AwaitOp>([&](auto awaitOp) {
          // Add async.await ops to wait for the !gpu.async.tokens.
          builder.setInsertionPointAfter(op);
          for (auto asyncToken : asyncTokens)
            tokens.push_back(
                async::AwaitOp::create(builder, loc, asyncToken).getResult());
          // Set `it` after the inserted async.await ops.
          it = builder.getInsertionPoint();
        })
        .Case<async::ExecuteOp>([&](auto executeOp) {
          // Set `it` to the beginning of the region and add asyncTokens to the
          // async.execute operands.
          it = executeOp.getBody()->begin();
          executeOp.getBodyOperandsMutable().append(asyncTokens);
          SmallVector<Type, 1> tokenTypes(
              asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
          SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
                                             executeOp.getLoc());
          copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
               std::back_inserter(tokens));
        });

    // Advance `it` to terminator or op with side-effects.
    it = std::find_if(it, Block::iterator(), [](Operation &op) {
      return isTerminator(&op) || hasSideEffects(&op);
    });

    // If the op at `it` implements the AsyncOpInterface, add the tokens to
    // its list of async dependencies.
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
      for (auto token : tokens)
        asyncOp.addAsyncDependency(token);
      return;
    }

    // Otherwise, insert a gpu.wait before 'it'.
    builder.setInsertionPoint(it->getBlock(), it);
    auto waitOp = gpu::WaitOp::create(builder, loc, Type{}, tokens);

    // If the new waitOp is at the end of an async.execute region, add it to
    // the worklist. 'operator()(executeOp)' would do the same, but this is
    // faster.
    auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
    if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
        !it->getNextNode())
      worklist.push_back(waitOp);
  }

  SmallVector<gpu::WaitOp, 8> worklist;
};

// Callback for `async.execute` ops which repeats !gpu.async.token results
// so that each of them is only used once.
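// A rough sketch (illustrative, not from this file): if the token result %g of
//   %a, %g = async.execute -> !async.value<!gpu.async.token> { ... }
// has two uses, the yielded token is repeated so that the clone
//   %a, %g, %g2 = async.execute
//       -> (!async.value<!gpu.async.token>, !async.value<!gpu.async.token>)
// provides a separate result for the second use.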
struct GpuAsyncRegionPass::SingleTokenUseCallback {
  void operator()(async::ExecuteOp executeOp) {
    // Extract !gpu.async.token results which have multiple uses.
    auto multiUseResults = llvm::make_filter_range(
        executeOp.getBodyResults(), [](OpResult result) {
          if (result.use_empty() || result.hasOneUse())
            return false;
          auto valueType = dyn_cast<async::ValueType>(result.getType());
          return valueType &&
                 isa<gpu::AsyncTokenType>(valueType.getValueType());
        });
    if (multiUseResults.empty())
      return;

    // Indices within !async.execute results (i.e. without the async.token).
    SmallVector<int, 4> indices;
    transform(multiUseResults, std::back_inserter(indices),
              [](OpResult result) {
                return result.getResultNumber() - 1; // Index without token.
              });

    for (auto index : indices) {
      assert(!executeOp.getBodyResults()[index].getUses().empty());
      // Repeat async.yield token result, one for each use after the first one.
      auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
      auto count = std::distance(uses.begin(), uses.end());
      auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
      SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
      executeOp = addExecuteResults(executeOp, operands);
      // Update 'uses' to refer to the new executeOp.
      uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
      auto results = executeOp.getBodyResults().take_back(count);
      for (auto pair : llvm::zip(uses, results))
        std::get<0>(pair).set(std::get<1>(pair));
    }
  }
};

// Replaces synchronous GPU ops in the op's region with asynchronous ones and
// inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
// execution semantics and that no GPU ops are asynchronous yet.
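// A minimal usage sketch (the "gpu-async-region" mnemonic comes from the GPU
// Passes.td definition; the C++ factory name below is the conventional
// autogenerated one and is an assumption here):
//   mlir-opt input.mlir -gpu-async-region
//   pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());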
void GpuAsyncRegionPass::runOnOperation() {
  if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
    return signalPassFailure();

  // Collect gpu.wait ops that we can move out of async.execute regions.
  getOperation().getRegion().walk(DeferWaitCallback());
  // Make each !gpu.async.token returned from an async.execute op have a
  // single use.
  getOperation().getRegion().walk(SingleTokenUseCallback());
}