27 #include "llvm/ADT/TypeSwitch.h"
30 #define GEN_PASS_DEF_GPUASYNCREGIONPASS
31 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
37 class GpuAsyncRegionPass
38 :
public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
39 struct ThreadTokenCallback;
40 struct DeferWaitCallback;
41 struct SingleTokenUseCallback;
42 void runOnOperation()
override;
57 for (
Operation &op : make_early_inc_range(*block)) {
58 if (failed(
visit(&op)))
74 if (isa<gpu::LaunchOp>(op))
75 return op->
emitOpError(
"replace with gpu.launch_func first");
76 if (
auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
78 waitOp.addAsyncDependency(currentToken);
79 currentToken = waitOp.getAsyncToken();
82 builder.setInsertionPoint(op);
83 if (
auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
84 return rewriteAsyncOp(asyncOp);
89 currentToken = createWaitOp(op->
getLoc(),
Type(), {currentToken});
94 LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
95 auto *op = asyncOp.getOperation();
101 currentToken = createWaitOp(op->
getLoc(), tokenType, {});
102 asyncOp.addAsyncDependency(currentToken);
105 currentToken = asyncOp.getAsyncToken();
113 resultTypes.push_back(tokenType);
121 for (
auto pair : llvm::zip_first(op->
getRegions(), newOp->getRegions()))
122 std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
125 auto results = newOp->getResults();
126 currentToken = results.back();
127 builder.insert(newOp);
135 return builder.create<gpu::WaitOp>(loc, resultType, operands)
145 Value currentToken = {};
152 Operation *yieldOp = executeOp.getBody()->getTerminator();
157 resultTypes.reserve(executeOp.getNumResults() + results.size());
158 transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
161 if (auto valueType = dyn_cast<async::ValueType>(type))
162 return valueType.getValueType();
163 assert(isa<async::TokenType>(type) &&
"expected token type");
166 transform(results, std::back_inserter(resultTypes),
171 auto newOp = builder.create<async::ExecuteOp>(
172 executeOp.getLoc(),
TypeRange{resultTypes}.drop_front() ,
173 executeOp.getDependencies(), executeOp.getBodyOperands());
175 newOp.getRegion().getBlocks().
clear();
176 executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
179 executeOp.getOperation()->replaceAllUsesWith(
180 newOp.getResults().drop_back(results.size()));
193 if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
196 for (
auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
197 if (
auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
198 if (!waitOp.getAsyncToken())
199 worklist.push_back(waitOp);
209 for (
size_t i = 0; i < worklist.size(); ++i) {
210 auto waitOp = worklist[i];
211 auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
219 auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
221 executeOp.getToken().user_end());
223 addAsyncDependencyAfter(asyncTokens, user);
233 static bool areAllUsersExecuteOrAwait(
Value token) {
236 llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
246 tokens.reserve(asyncTokens.size());
248 .Case<async::AwaitOp>([&](
auto awaitOp) {
250 builder.setInsertionPointAfter(op);
251 for (
auto asyncToken : asyncTokens)
253 builder.create<async::AwaitOp>(loc, asyncToken).getResult());
255 it = builder.getInsertionPoint();
257 .Case<async::ExecuteOp>([&](
auto executeOp) {
260 it = executeOp.getBody()->begin();
261 executeOp.getBodyOperandsMutable().append(asyncTokens);
266 copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
267 std::back_inserter(tokens));
277 if (
auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
278 for (
auto token : tokens)
279 asyncOp.addAsyncDependency(token);
284 builder.setInsertionPoint(it->getBlock(), it);
285 auto waitOp = builder.create<gpu::WaitOp>(loc,
Type{}, tokens);
289 auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
290 if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
292 worklist.push_back(waitOp);
303 auto multiUseResults = llvm::make_filter_range(
304 executeOp.getBodyResults(), [](
OpResult result) {
305 if (result.use_empty() || result.hasOneUse())
307 auto valueType = dyn_cast<async::ValueType>(result.getType());
309 isa<gpu::AsyncTokenType>(valueType.getValueType());
311 if (multiUseResults.empty())
316 transform(multiUseResults, std::back_inserter(indices),
321 for (
auto index : indices) {
322 assert(!executeOp.getBodyResults()[index].getUses().empty());
324 auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
325 auto count = std::distance(uses.begin(), uses.end());
326 auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
330 uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
331 auto results = executeOp.getBodyResults().take_back(count);
332 for (
auto pair : llvm::zip(uses, results))
333 std::get<0>(pair).set(std::get<1>(pair));
341 void GpuAsyncRegionPass::runOnOperation() {
342 if (getOperation()->
walk(ThreadTokenCallback(
getContext())).wasInterrupted())
343 return signalPassFailure();
346 getOperation().getRegion().walk(DeferWaitCallback());
348 getOperation().getRegion().walk(SingleTokenUseCallback());
352 return std::make_unique<GpuAsyncRegionPass>();
static bool isTerminator(Operation *op)
async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, ValueRange results)
Erases executeOp and returns a clone with additional results.
static bool hasSideEffects(Operation *op)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
OpListType::iterator iterator
This is a utility class for mapping one set of IR entities to another.
void clear()
Clears all mappings held by the mapper.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
SuccessorRange getSuccessors()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
user_range getUsers() const
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< OperationPass< func::FuncOp > > createGpuAsyncRegionPass()
Rewrites a function region so that GPU ops execute asynchronously.
void operator()(async::ExecuteOp executeOp)
void operator()(async::ExecuteOp executeOp)
WalkResult operator()(Block *block)
ThreadTokenCallback(MLIRContext &context)