MLIR  16.0.0git
AsyncRegionRewriter.cpp
Go to the documentation of this file.
1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GPU dialect pattern rewriters that make GPU ops
10 // within a region execute asynchronously.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/SymbolTable.h"
23 #include "mlir/Support/LLVM.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 using namespace mlir;
28 namespace {
// Pass over func::FuncOp (see createGpuAsyncRegionPass) that rewrites the GPU
// ops in its region to execute asynchronously. The three nested callback
// structs are only declared here; they are defined further down in this file
// and are applied one after the other by runOnOperation().
29 class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
    // Threads a !gpu.async.token through the GPU ops of each block.
30  struct ThreadTokenCallback;
    // Moves trailing synchronous gpu.wait ops out of async.execute regions.
31  struct DeferWaitCallback;
    // Duplicates !gpu.async.token results so that each has a single use.
32  struct SingleTokenUseCallback;
33  void runOnOperation() override;
34 };
35 } // namespace
36 
37 static bool isTerminator(Operation *op) {
39 }
// Returns true unless MemoryEffectOpInterface::hasNoEffect reports that `op`
// is free of memory effects. Used in this file to decide where host
// synchronization has to be placed and where a search for deferrable
// `gpu.wait` ops has to stop.
40 static bool hasSideEffects(Operation *op) {
41  return !MemoryEffectOpInterface::hasNoEffect(op);
42 }
43 
44 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
45 // execute asynchronously.
// Creates the callback with its OpBuilder bound to `context`. Declared
// `explicit` so that an MLIRContext is never implicitly converted into a
// callback object; direct construction from a context keeps working.
47  explicit ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
48 
50  for (Operation &op : make_early_inc_range(*block)) {
51  if (failed(visit(&op)))
52  return WalkResult::interrupt();
53  }
54  return WalkResult::advance();
55  }
56 
57 private:
58  // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
59  // create a current token (unless it already exists), and 'thread' that token
60  // through the `op` so that it executes asynchronously.
61  //
62  // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
63  // host-synchronize execution. A `!gpu.async.token` will therefore only be
64  // used inside of its block and GPU execution will always synchronize with
65  // the host at block boundaries.
67  if (isa<gpu::LaunchOp>(op))
68  return op->emitOpError("replace with gpu.launch_func first");
69  if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
70  if (currentToken)
71  waitOp.addAsyncDependency(currentToken);
72  currentToken = waitOp.asyncToken();
73  return success();
74  }
75  builder.setInsertionPoint(op);
76  if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
77  return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
78  if (!currentToken)
79  return success();
80  // Insert host synchronization before terminator or op with side effects.
81  if (isTerminator(op) || hasSideEffects(op))
82  currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
83  return success();
84  }
85 
86  // Threads the current token through `asyncOp` and makes the token produced
    // by the op the new current token.
    //
    // If there is no current token yet, a dependency-free `gpu.wait async` is
    // created first. The current token is then added to the op's async
    // dependencies. If the op already returns a token, it is kept as is.
    // Otherwise the op is re-created with one extra `!gpu.async.token` result
    // appended after its original results, its regions are cloned into the new
    // op, all uses of the old results are redirected, and the old op is erased.
    //
    // The new op is inserted at the builder's current insertion point.
    // Every path returns success().
87  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
88  auto *op = asyncOp.getOperation();
89  auto tokenType = builder.getType<gpu::AsyncTokenType>();
90 
91  // If there is no current token, insert a `gpu.wait async` without
92  // dependencies to create one.
93  if (!currentToken)
94  currentToken = createWaitOp(op->getLoc(), tokenType, {});
95  asyncOp.addAsyncDependency(currentToken);
96 
97  // Return early if op returns a token already. Note that `currentToken` is
    // overwritten unconditionally: it is null here when the op has no token.
98  currentToken = asyncOp.getAsyncToken();
99  if (currentToken)
100  return success();
101 
102  // Clone the op to return a token in addition to the other results. The
    // token type goes last so that the original results keep their positions.
103  SmallVector<Type, 1> resultTypes;
104  resultTypes.reserve(1 + op->getNumResults());
105  copy(op->getResultTypes(), std::back_inserter(resultTypes));
106  resultTypes.push_back(tokenType);
107  auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
108  op->getOperands(), op->getAttrDictionary(),
109  op->getSuccessors(), op->getNumRegions());
110 
111  // Clone regions into new op, pairing each original region with the new
    // op's region at the same position.
112  BlockAndValueMapping mapping;
113  for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
114  std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
115 
116  // Replace the op with the async clone. The appended token is the last
    // result; all results before it stand in for the original results.
117  auto results = newOp->getResults();
118  currentToken = results.back();
119  builder.insert(newOp);
120  op->replaceAllUsesWith(results.drop_back());
121  op->erase();
122 
123  return success();
124  }
125 
// Creates a `gpu.wait` op at the builder's insertion point with the given
// async dependencies (`operands`) and returns the op's async token.
// Callers in this file pass the token type to obtain a `gpu.wait async`, and
// a null `Type()` for a host-synchronizing wait; in the latter case the
// returned Value is presumably null (confirm against gpu::WaitOp).
126  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
127  return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken();
128  }
129 
130  OpBuilder builder;
131 
132  // The token that represents the current asynchronous dependency. Its valid
133  // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
134  // In between, each gpu::AsyncOpInterface depends on the current token and
135  // produces the new one.
136  Value currentToken = {};
137 };
138 
139 /// Erases `executeOp` and returns a clone with additional `results`.
    ///
    /// `results` are appended to the operands of the terminator of the
    /// `executeOp` body, so they must be values usable at that point. The
    /// returned op has the original results at their original positions,
    /// followed by one new result per entry in `results`. All uses of the old
    /// op's results are redirected to the matching results of the new op.
    ///
    /// NOTE(review): this function has external linkage although it appears to
    /// be a file-local helper; consider making it `static` after confirming
    /// that no other translation unit refers to it.
140 async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
141  ValueRange results) {
142  // Add values to async.yield op.
143  Operation *yieldOp = executeOp.getBody()->getTerminator();
144  yieldOp->insertOperands(yieldOp->getNumOperands(), results);
145 
146  // Construct new result type list with additional types. Existing
    // !async.value result types are unwrapped to their value type; the only
    // other type expected is the async token type, which is kept as is.
147  SmallVector<Type, 2> resultTypes;
148  resultTypes.reserve(executeOp.getNumResults() + results.size());
149  transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
150  [](Type type) {
151  // Extract value type from !async.value.
152  if (auto valueType = type.dyn_cast<async::ValueType>())
153  return valueType.getValueType();
154  assert(type.isa<async::TokenType>() && "expected token type");
155  return type;
156  });
157  transform(results, std::back_inserter(resultTypes),
158  [](Value value) { return value.getType(); });
159 
160  // Clone executeOp with the extra results. The leading token type is dropped
    // from the type list handed to the builder (presumably the builder adds
    // the token result itself and re-wraps the value types; confirm against
    // the async::ExecuteOp builder).
161  OpBuilder builder(executeOp);
162  auto newOp = builder.create<async::ExecuteOp>(
163  executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
164  executeOp.dependencies(), executeOp.operands());
165  BlockAndValueMapping mapper;
    // Discard the blocks the builder created so that the region consists
    // solely of the blocks cloned from `executeOp`.
166  newOp.getRegion().getBlocks().clear();
167  executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
168 
169  // Replace executeOp with cloned one. The trailing `results.size()` results
    // are the newly added ones and have no counterpart in the old op.
170  executeOp.getOperation()->replaceAllUsesWith(
171  newOp.getResults().drop_back(results.size()));
172  executeOp.erase();
173 
174  return newOp;
175 }
176 
177 // Callback for `async.execute` ops which tries to push the contained
178 // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
180  // If the token of `executeOp` is used only in `async.execute` or
181  // `async.await` ops, add the region's last `gpu.wait` op to the worklist if
182  // it is synchronous and is the last op with side effects.
     //
     // This only records candidates in `worklist`; no IR is modified here.
183  void operator()(async::ExecuteOp executeOp) {
184  if (!areAllUsersExecuteOrAwait(executeOp.token()))
185  return;
186  // async.execute's region is currently restricted to one block. Scan it
     // backwards, skipping the terminator.
187  for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
188  if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
     // A wait without an async token result is a synchronous one. The scan
     // ends at the last `gpu.wait` whether or not it was recorded.
189  if (!waitOp.asyncToken())
190  worklist.push_back(waitOp);
191  return;
192  }
     // An op with side effects follows the last `gpu.wait` (if any), so there
     // is nothing to record.
193  if (hasSideEffects(&op))
194  return;
195  }
196  }
197 
198  // The destructor performs the actual rewrite work.
200  for (size_t i = 0; i < worklist.size(); ++i) {
201  auto waitOp = worklist[i];
202  auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
203 
204  // Erase `gpu.wait` and return async dependencies from execute op instead.
205  SmallVector<Value, 4> dependencies = waitOp.asyncDependencies();
206  waitOp.erase();
207  executeOp = addExecuteResults(executeOp, dependencies);
208 
209  // Add the async dependency to each user of the `async.execute` token.
210  auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
211  SmallVector<Operation *, 4> users(executeOp.token().user_begin(),
212  executeOp.token().user_end());
213  for (Operation *user : users)
214  addAsyncDependencyAfter(asyncTokens, user);
215  }
216  }
217 
218 private:
219  // Returns whether all token users are either 'async.execute' or 'async.await'
220  // ops. This is used as a requirement for pushing 'gpu.wait' ops from an
221  // 'async.execute' body to its users. Specifically, we do not allow
222  // terminator users, because it could mean that the `async.execute` is inside
223  // control flow code.
     //
     // Also returns false if `token` has no uses at all.
224  static bool areAllUsersExecuteOrAwait(Value token) {
225  return !token.use_empty() &&
226  llvm::all_of(token.getUsers(), [](Operation *user) {
227  return isa<async::ExecuteOp, async::AwaitOp>(user);
228  });
229  }
230 
231  // Add the `asyncToken` as dependency as needed after `op`.
232  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
233  OpBuilder builder(op->getContext());
234  auto loc = op->getLoc();
235 
236  Block::iterator it;
237  SmallVector<Value, 1> tokens;
238  tokens.reserve(asyncTokens.size());
240  .Case<async::AwaitOp>([&](auto awaitOp) {
241  // Add async.await ops to wait for the !gpu.async.tokens.
242  builder.setInsertionPointAfter(op);
243  for (auto asyncToken : asyncTokens)
244  tokens.push_back(
245  builder.create<async::AwaitOp>(loc, asyncToken).result());
246  // Set `it` after the inserted async.await ops.
247  it = builder.getInsertionPoint();
248  })
249  .Case<async::ExecuteOp>([&](auto executeOp) {
250  // Set `it` to the beginning of the region and add asyncTokens to the
251  // async.execute operands.
252  it = executeOp.getBody()->begin();
253  executeOp.operandsMutable().append(asyncTokens);
254  SmallVector<Type, 1> tokenTypes(
255  asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
256  SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
257  executeOp.getLoc());
258  copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
259  std::back_inserter(tokens));
260  });
261 
262  // Advance `it` to terminator or op with side-effects.
263  it = std::find_if(it, Block::iterator(), [](Operation &op) {
264  return isTerminator(&op) || hasSideEffects(&op);
265  });
266 
267  // If `op` implements the AsyncOpInterface, add `token` to the list of async
268  // dependencies.
269  if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
270  for (auto token : tokens)
271  asyncOp.addAsyncDependency(token);
272  return;
273  }
274 
275  // Otherwise, insert a gpu.wait before 'it'.
276  builder.setInsertionPoint(it->getBlock(), it);
277  auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
278 
279  // If the new waitOp is at the end of an async.execute region, add it to the
280  // worklist. 'operator()(executeOp)' would do the same, but this is faster.
281  auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
282  if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
283  !it->getNextNode())
284  worklist.push_back(waitOp);
285  }
286 
288 };
289 
290 // Callback for `async.execute` ops which repeats !gpu.async.token results
291 // so that each of them is only used once.
// For every result of `executeOp` that wraps a !gpu.async.token and has more
// than one use, adds one extra result per additional use (yielding the same
// value from the body) and redirects each additional use to one of the new
// results. The first use keeps referring to the original result position.
293  void operator()(async::ExecuteOp executeOp) {
294  // Extract !gpu.async.token results which have multiple uses.
295  auto multiUseResults =
296  llvm::make_filter_range(executeOp.results(), [](OpResult result) {
297  if (result.use_empty() || result.hasOneUse())
298  return false;
299  auto valueType = result.getType().dyn_cast<async::ValueType>();
300  return valueType &&
301  valueType.getValueType().isa<gpu::AsyncTokenType>();
302  });
303  if (multiUseResults.empty())
304  return;
305 
306  // Indices within !async.execute results (i.e. without the async.token).
     // Positions are stored rather than OpResults because `executeOp` is
     // replaced inside the loop below, while the positions of the original
     // results stay the same (new results are only appended).
307  SmallVector<int, 4> indices;
308  transform(multiUseResults, std::back_inserter(indices),
309  [](OpResult result) {
310  return result.getResultNumber() - 1; // Index without token.
311  });
312 
313  for (auto index : indices) {
314  assert(!executeOp.results()[index].getUses().empty());
315  // Repeat async.yield token result, one for each use after the first one.
316  auto uses = llvm::drop_begin(executeOp.results()[index].getUses());
317  auto count = std::distance(uses.begin(), uses.end());
318  auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
319  SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
320  executeOp = addExecuteResults(executeOp, operands);
321  // Update 'uses' to refer to the new executeOp.
322  uses = llvm::drop_begin(executeOp.results()[index].getUses());
     // The `count` results appended by addExecuteResults are the trailing
     // ones; pair each remaining use with one of them.
323  auto results = executeOp.results().take_back(count);
324  for (auto pair : llvm::zip(uses, results))
325  std::get<0>(pair).set(std::get<1>(pair));
326  }
327  }
328 };
329 
330 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
331 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
332 // execution semantics and that no GPU ops are asynchronous yet.
    //
    // Runs three walks in sequence; the pass is marked as failed, and the last
    // two walks are skipped, if the first walk is interrupted.
333 void GpuAsyncRegionPass::runOnOperation() {
    // Thread a !gpu.async.token through the GPU ops of every block.
334  if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
335  return signalPassFailure();
336 
337  // Collect gpu.wait ops that we can move out of async.execute regions.
338  getOperation().getRegion().walk(DeferWaitCallback());
339  // Makes each !gpu.async.token returned from async.execute op have single use.
340  getOperation().getRegion().walk(SingleTokenUseCallback());
341 }
342 
// Factory for the pass defined in this file: returns a newly allocated
// GpuAsyncRegionPass as an operation pass on func::FuncOp.
343 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
344  return std::make_unique<GpuAsyncRegionPass>();
345 }
Include the generated interface declarations.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
This is a value defined by a result of an operation.
Definition: Value.h:425
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Block represents an ordered list of Operations.
Definition: Block.h:29
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:263
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:536
void operator()(async::ExecuteOp executeOp)
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:707
std::unique_ptr< OperationPass< func::FuncOp > > createGpuAsyncRegionPass()
Rewrites a function region so that GPU ops execute asynchronously.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:209
user_range getUsers() const
Definition: Value.h:213
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:414
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
void operator()(async::ExecuteOp executeOp)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:49
OpListType::iterator iterator
Definition: Block.h:131
void clear()
Clears all mappings held by the mapper.
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:437
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:203
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
static WalkResult advance()
Definition: Visitors.h:51
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
static WalkResult interrupt()
Definition: Visitors.h:50
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:359
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:203
static bool hasSideEffects(Operation *op)
SuccessorRange getSuccessors()
Definition: Operation.h:503
static bool isTerminator(Operation *op)
Type getType() const
Return the type of this value.
Definition: Value.h:118
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
Definition: Visitors.cpp:24
type_range getType() const
Definition: ValueRange.cpp:46
async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, ValueRange results)
Erases executeOp and returns a clone with additional results.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:508
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
static void visit(Operation *op, DenseSet< Operation *> &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation...
Definition: PDL.cpp:62
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
result_type_range getResultTypes()
Definition: Operation.h:345