MLIR  21.0.0git
Go to the documentation of this file.
1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GPU dialect pattern rewriters that make GPU ops
10 // within a region execute asynchronously.
11 //
12 //===----------------------------------------------------------------------===//
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/SymbolTable.h"
25 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/TypeSwitch.h"
29 namespace mlir {
31 #include "mlir/Dialect/GPU/Transforms/"
32 } // namespace mlir
34 using namespace mlir;
36 namespace {
37 class GpuAsyncRegionPass
38  : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
39  struct ThreadTokenCallback;
40  struct DeferWaitCallback;
41  struct SingleTokenUseCallback;
42  void runOnOperation() override;
43 };
44 } // namespace
46 static bool isTerminator(Operation *op) {
48 }
49 static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); }
51 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
52 // execute asynchronously.
54  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
57  for (Operation &op : make_early_inc_range(*block)) {
58  if (failed(visit(&op)))
59  return WalkResult::interrupt();
60  }
61  return WalkResult::advance();
62  }
64 private:
65  // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
66  // create a current token (unless it already exists), and 'thread' that token
67  // through the `op` so that it executes asynchronously.
68  //
69  // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
70  // host-synchronize execution. A `!gpu.async.token` will therefore only be
71  // used inside of its block and GPU execution will always synchronize with
72  // the host at block boundaries.
73  LogicalResult visit(Operation *op) {
74  if (isa<gpu::LaunchOp>(op))
75  return op->emitOpError("replace with gpu.launch_func first");
76  if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
77  if (currentToken)
78  waitOp.addAsyncDependency(currentToken);
79  currentToken = waitOp.getAsyncToken();
80  return success();
81  }
82  builder.setInsertionPoint(op);
83  if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
84  return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
85  if (!currentToken)
86  return success();
87  // Insert host synchronization before terminator or op with side effects.
88  if (isTerminator(op) || hasSideEffects(op))
89  currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
90  return success();
91  }
93  // Replaces asyncOp with a clone that returns a token.
94  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
95  auto *op = asyncOp.getOperation();
96  auto tokenType = builder.getType<gpu::AsyncTokenType>();
98  // If there is no current token, insert a `gpu.wait async` without
99  // dependencies to create one.
100  if (!currentToken)
101  currentToken = createWaitOp(op->getLoc(), tokenType, {});
102  asyncOp.addAsyncDependency(currentToken);
104  // Return early if op returns a token already.
105  currentToken = asyncOp.getAsyncToken();
106  if (currentToken)
107  return success();
109  // Clone the op to return a token in addition to the other results.
110  SmallVector<Type, 1> resultTypes;
111  resultTypes.reserve(1 + op->getNumResults());
112  copy(op->getResultTypes(), std::back_inserter(resultTypes));
113  resultTypes.push_back(tokenType);
114  auto *newOp = Operation::create(
115  op->getLoc(), op->getName(), resultTypes, op->getOperands(),
117  op->getSuccessors(), op->getNumRegions());
119  // Clone regions into new op.
120  IRMapping mapping;
121  for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
122  std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
124  // Replace the op with the async clone.
125  auto results = newOp->getResults();
126  currentToken = results.back();
127  builder.insert(newOp);
128  op->replaceAllUsesWith(results.drop_back());
129  op->erase();
131  return success();
132  }
134  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
135  return builder.create<gpu::WaitOp>(loc, resultType, operands)
136  .getAsyncToken();
137  }
139  OpBuilder builder;
141  // The token that represents the current asynchronous dependency. It's valid
142  // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
143  // In between, each gpu::AsyncOpInterface depends on the current token and
144  // produces the new one.
145  Value currentToken = {};
146 };
148 /// Erases `executeOp` and returns a clone with additional `results`.
149 async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
150  ValueRange results) {
151  // Add values to async.yield op.
152  Operation *yieldOp = executeOp.getBody()->getTerminator();
153  yieldOp->insertOperands(yieldOp->getNumOperands(), results);
155  // Construct new result type list with additional types.
156  SmallVector<Type, 2> resultTypes;
157  resultTypes.reserve(executeOp.getNumResults() + results.size());
158  transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
159  [](Type type) {
160  // Extract value type from !async.value.
161  if (auto valueType = dyn_cast<async::ValueType>(type))
162  return valueType.getValueType();
163  assert(isa<async::TokenType>(type) && "expected token type");
164  return type;
165  });
166  transform(results, std::back_inserter(resultTypes),
167  [](Value value) { return value.getType(); });
169  // Clone executeOp with the extra results.
170  OpBuilder builder(executeOp);
171  auto newOp = builder.create<async::ExecuteOp>(
172  executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
173  executeOp.getDependencies(), executeOp.getBodyOperands());
174  IRMapping mapper;
175  newOp.getRegion().getBlocks().clear();
176  executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
178  // Replace executeOp with cloned one.
179  executeOp.getOperation()->replaceAllUsesWith(
180  newOp.getResults().drop_back(results.size()));
181  executeOp.erase();
183  return newOp;
184 }
186 // Callback for `async.execute` ops which tries to push the contained
187 // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
189  // If the `executeOp`s token is used only in `async.execute` or `async.await`
190  // ops, add the region's last `gpu.wait` op to the worklist if it is
191  // synchronous and is the last op with side effects.
192  void operator()(async::ExecuteOp executeOp) {
193  if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
194  return;
195  // async.execute's region is currently restricted to one block.
196  for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
197  if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
198  if (!waitOp.getAsyncToken())
199  worklist.push_back(waitOp);
200  return;
201  }
202  if (hasSideEffects(&op))
203  return;
204  }
205  }
207  // The destructor performs the actual rewrite work.
209  for (size_t i = 0; i < worklist.size(); ++i) {
210  auto waitOp = worklist[i];
211  auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
213  // Erase `gpu.wait` and return async dependencies from execute op instead.
214  SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
215  waitOp.erase();
216  executeOp = addExecuteResults(executeOp, dependencies);
218  // Add the async dependency to each user of the `async.execute` token.
219  auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
220  SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
221  executeOp.getToken().user_end());
222  for (Operation *user : users)
223  addAsyncDependencyAfter(asyncTokens, user);
224  }
225  }
227 private:
228  // Returns whether all token users are either 'async.execute' or 'async.await'
229  // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
230  // 'async.execute' body to it's users. Specifically, we do not allow
231  // terminator users, because it could mean that the `async.execute` is inside
232  // control flow code.
233  static bool areAllUsersExecuteOrAwait(Value token) {
234  return !token.use_empty() &&
235  llvm::all_of(token.getUsers(),
236  llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
237  }
239  // Add the `asyncToken` as dependency as needed after `op`.
240  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
241  OpBuilder builder(op->getContext());
242  auto loc = op->getLoc();
244  Block::iterator it;
245  SmallVector<Value, 1> tokens;
246  tokens.reserve(asyncTokens.size());
248  .Case<async::AwaitOp>([&](auto awaitOp) {
249  // Add async.await ops to wait for the !gpu.async.tokens.
250  builder.setInsertionPointAfter(op);
251  for (auto asyncToken : asyncTokens)
252  tokens.push_back(
253  builder.create<async::AwaitOp>(loc, asyncToken).getResult());
254  // Set `it` after the inserted async.await ops.
255  it = builder.getInsertionPoint();
256  })
257  .Case<async::ExecuteOp>([&](auto executeOp) {
258  // Set `it` to the beginning of the region and add asyncTokens to the
259  // async.execute operands.
260  it = executeOp.getBody()->begin();
261  executeOp.getBodyOperandsMutable().append(asyncTokens);
262  SmallVector<Type, 1> tokenTypes(
263  asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
264  SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
265  executeOp.getLoc());
266  copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
267  std::back_inserter(tokens));
268  });
270  // Advance `it` to terminator or op with side-effects.
271  it = std::find_if(it, Block::iterator(), [](Operation &op) {
272  return isTerminator(&op) || hasSideEffects(&op);
273  });
275  // If `op` implements the AsyncOpInterface, add `token` to the list of async
276  // dependencies.
277  if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
278  for (auto token : tokens)
279  asyncOp.addAsyncDependency(token);
280  return;
281  }
283  // Otherwise, insert a gpu.wait before 'it'.
284  builder.setInsertionPoint(it->getBlock(), it);
285  auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
287  // If the new waitOp is at the end of an async.execute region, add it to the
288  // worklist. 'operator()(executeOp)' would do the same, but this is faster.
289  auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
290  if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
291  !it->getNextNode())
292  worklist.push_back(waitOp);
293  }
296 };
298 // Callback for `async.execute` ops which repeats !gpu.async.token results
299 // so that each of them is only used once.
301  void operator()(async::ExecuteOp executeOp) {
302  // Extract !gpu.async.token results which have multiple uses.
303  auto multiUseResults = llvm::make_filter_range(
304  executeOp.getBodyResults(), [](OpResult result) {
305  if (result.use_empty() || result.hasOneUse())
306  return false;
307  auto valueType = dyn_cast<async::ValueType>(result.getType());
308  return valueType &&
309  isa<gpu::AsyncTokenType>(valueType.getValueType());
310  });
311  if (multiUseResults.empty())
312  return;
314  // Indices within !async.execute results (i.e. without the async.token).
315  SmallVector<int, 4> indices;
316  transform(multiUseResults, std::back_inserter(indices),
317  [](OpResult result) {
318  return result.getResultNumber() - 1; // Index without token.
319  });
321  for (auto index : indices) {
322  assert(!executeOp.getBodyResults()[index].getUses().empty());
323  // Repeat async.yield token result, one for each use after the first one.
324  auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
325  auto count = std::distance(uses.begin(), uses.end());
326  auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
327  SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
328  executeOp = addExecuteResults(executeOp, operands);
329  // Update 'uses' to refer to the new executeOp.
330  uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
331  auto results = executeOp.getBodyResults().take_back(count);
332  for (auto pair : llvm::zip(uses, results))
333  std::get<0>(pair).set(std::get<1>(pair));
334  }
335  }
336 };
338 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
339 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
340 // execution semantics and that no GPU ops are asynchronous yet.
341 void GpuAsyncRegionPass::runOnOperation() {
342  if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
343  return signalPassFailure();
345  // Collect gpu.wait ops that we can move out of async.execute regions.
346  getOperation().getRegion().walk(DeferWaitCallback());
347  // Makes each !gpu.async.token returned from async.execute op have single use.
348  getOperation().getRegion().walk(SingleTokenUseCallback());
349 }
static bool isTerminator(Operation *op)
async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, ValueRange results)
Erases executeOp and returns a clone with additional results.
static bool hasSideEffects(Operation *op)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void clear()
Clears all mappings held by the mapper.
Definition: IRMapping.h:79
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:768
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:758
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:501
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
SuccessorRange getSuccessors()
Definition: Operation.h:704
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
void operator()(async::ExecuteOp executeOp)
void operator()(async::ExecuteOp executeOp)