#include "llvm/ADT/TypeSwitch.h"

class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
  struct ThreadTokenCallback;
  struct DeferWaitCallback;
  struct SingleTokenUseCallback;
  void runOnOperation() override;
};
// Ops are treated as having side effects unless the memory-effect interface
// reports that they have none.
static bool hasSideEffects(Operation *op) {
  return !MemoryEffectOpInterface::hasNoEffect(op);
}

// ThreadTokenCallback(MLIRContext &context): per-block walk callback that
// threads a !gpu.async.token through the GPU ops of the block.
for (Operation &op : make_early_inc_range(*block)) {
  // ... visit each op; interrupt the walk if visiting fails ...
}
// Inside the per-op visitor:
if (isa<gpu::LaunchOp>(op))
  return op->emitOpError("replace with gpu.launch_func first");
if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
  // Feed the current token into the gpu.wait and continue with its result
  // token (null again if the wait is synchronous).
  if (currentToken)
    waitOp.addAsyncDependency(currentToken);
  currentToken = waitOp.asyncToken();
  // ...
}
builder.setInsertionPoint(op);
if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
  return rewriteAsyncOp(asyncOp);
// ...
// Synchronize with the host before terminators and ops with side effects.
currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
// rewriteAsyncOp(asyncOp): thread the current token through a GPU op that
// implements gpu::AsyncOpInterface.
auto *op = asyncOp.getOperation();
// ...
// If there is no current token yet, create one with a `gpu.wait async`.
if (!currentToken)
  currentToken = createWaitOp(op->getLoc(), tokenType, {});
asyncOp.addAsyncDependency(currentToken);
// If the op already returns a token, we are done.
currentToken = asyncOp.getAsyncToken();
// ...
// Otherwise clone the op so that it additionally returns a token.
resultTypes.push_back(tokenType);
// ...
for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
  std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
// ...
// The clone's trailing result becomes the new current token.
auto results = newOp->getResults();
currentToken = results.back();
builder.insert(newOp);
// ... replace the old op's uses with the clone's other results and erase it ...

// Creates a gpu.wait op; returns its token (null if resultType is null).
Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
  return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken();
}

// The token of the current asynchronous dependency chain within a block.
Value currentToken = {};
// addExecuteResults(executeOp, results): erases `executeOp` and returns a
// clone that additionally yields `results`.
Operation *yieldOp = executeOp.getBody()->getTerminator();
yieldOp->insertOperands(yieldOp->getNumOperands(), results);
// ...
resultTypes.reserve(executeOp.getNumResults() + results.size());
transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
          [](Type type) {
            // Unwrap !async.value<T> results to their payload type T.
            if (auto valueType = type.dyn_cast<async::ValueType>())
              return valueType.getValueType();
            assert(type.isa<async::TokenType>() && "expected token type");
            return type;
          });
// Append the types of the additional results.
transform(results, std::back_inserter(resultTypes),
          [](Value value) { return value.getType(); });
// ...
auto newOp = builder.create<async::ExecuteOp>(
    executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
    executeOp.dependencies(), executeOp.operands());
newOp.getRegion().getBlocks().clear();
executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
// ...
// Replace all uses of the old op with the clone's matching results.
executeOp.getOperation()->replaceAllUsesWith(
    newOp.getResults().drop_back(results.size()));
// DeferWaitCallback::operator()(async::ExecuteOp executeOp): if the execute
// op's token is only consumed by async.execute / async.await ops, remember the
// region's trailing synchronous gpu.wait so it can be deferred to those users.
if (!areAllUsersExecuteOrAwait(executeOp.token()))
  return;
for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
  if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
    if (!waitOp.asyncToken())
      worklist.push_back(waitOp);
    // ...
  }
  // ...
}

// Worklist processing: erase each deferred gpu.wait, return its dependencies
// from the enclosing async.execute instead, and rewire the token's users.
for (size_t i = 0; i < worklist.size(); ++i) {
  auto waitOp = worklist[i];
  auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
  // ...
  auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
  // Copy the users first, since rewiring them mutates the use list.
  SmallVector<Operation *> users(executeOp.token().user_begin(),
                                 executeOp.token().user_end());
  for (Operation *user : users)
    addAsyncDependencyAfter(asyncTokens, user);
}

// Returns whether all users of the token are async.execute or async.await ops.
static bool areAllUsersExecuteOrAwait(Value token) {
  return !token.use_empty() &&
         llvm::all_of(token.getUsers(), [](Operation *user) {
           return isa<async::ExecuteOp, async::AwaitOp>(user);
         });
}
// addAsyncDependencyAfter(asyncTokens, op): make `op` (a user of the deferred
// token) wait on `asyncTokens` as well. The builder, loc, `it` iterator and
// `tokens` vector are locals whose declarations are elided here.
tokens.reserve(asyncTokens.size());
// Dispatch on the kind of user to find the right insertion point `it`.
llvm::TypeSwitch<Operation *>(op)
    .Case<async::AwaitOp>([&](auto awaitOp) {
      // Await each !gpu.async.token right after the existing async.await.
      builder.setInsertionPointAfter(op);
      for (auto asyncToken : asyncTokens)
        tokens.push_back(
            builder.create<async::AwaitOp>(loc, asyncToken).result());
      it = builder.getInsertionPoint();
      // ...
    })
    .Case<async::ExecuteOp>([&](auto executeOp) {
      // Pass the tokens into the async.execute region as extra operands and
      // block arguments.
      it = executeOp.getBody()->begin();
      executeOp.operandsMutable().append(asyncTokens);
      // ...
      copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
           std::back_inserter(tokens));
      // ...
    });
// ...
// If the op at the insertion point is itself a GPU async op, the tokens can
// become direct dependencies of it.
if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
  for (auto token : tokens)
    asyncOp.addAsyncDependency(token);
  return;
}
// Otherwise insert a synchronous gpu.wait on the tokens.
builder.setInsertionPoint(it->getBlock(), it);
auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
// ...
// If the new gpu.wait ends another async.execute region, it may itself be
// deferred in a later worklist iteration.
auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
    /* ... the gpu.wait is the region's last relevant op ... */)
  worklist.push_back(waitOp);
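The addAsyncDependencyAfter fragment above dispatches on the type of the token's user with llvm::TypeSwitch, the header included at the top of this file. As a reference for that idiom only, here is a self-contained sketch over a hypothetical Shape/Square/Circle hierarchy; none of these types belong to the pass, they merely stand in for the Operation / async::AwaitOp / async::ExecuteOp cases above.

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <string>

namespace {
// Hypothetical LLVM-style hierarchy with classof() so isa/dyn_cast work.
struct Shape {
  enum Kind { SquareKind, CircleKind };
  Kind kind;
  explicit Shape(Kind k) : kind(k) {}
};
struct Square : Shape {
  Square() : Shape(SquareKind) {}
  static bool classof(const Shape *s) { return s->kind == SquareKind; }
};
struct Circle : Shape {
  Circle() : Shape(CircleKind) {}
  static bool classof(const Shape *s) { return s->kind == CircleKind; }
};
} // namespace

// Dispatch on the dynamic type and produce a result, the same shape as the
// .Case<async::AwaitOp>/.Case<async::ExecuteOp> chain in the pass.
static std::string describe(Shape *shape) {
  return llvm::TypeSwitch<Shape *, std::string>(shape)
      .Case<Square>([](Square *) { return "square"; })
      .Case<Circle>([](Circle *) { return "circle"; })
      .Default([](Shape *) { return "unknown"; });
}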
// SingleTokenUseCallback::operator()(async::ExecuteOp executeOp): duplicate
// !gpu.async.token results with more than one use so that every token result
// ends up with exactly one user.
auto multiUseResults =
    llvm::make_filter_range(executeOp.results(), [](OpResult result) {
      if (result.use_empty() || result.hasOneUse())
        return false;
      auto valueType = result.getType().dyn_cast<async::ValueType>();
      return valueType && valueType.getValueType().isa<gpu::AsyncTokenType>();
    });
if (multiUseResults.empty())
  return;
// ...
// Collect the indices of the multi-use results, relative to results() (i.e.
// not counting the leading !async.token result).
transform(multiUseResults, std::back_inserter(indices),
          [](OpResult result) { return result.getResultNumber() - 1; });
for (auto index : indices) {
  assert(!executeOp.results()[index].getUses().empty());
  // Every use after the first one gets its own copy of the yielded token.
  auto uses = llvm::drop_begin(executeOp.results()[index].getUses());
  auto count = std::distance(uses.begin(), uses.end());
  auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
  // ... yield `count` extra copies of the token via addExecuteResults() ...
  uses = llvm::drop_begin(executeOp.results()[index].getUses());
  auto results = executeOp.results().take_back(count);
  for (auto pair : llvm::zip(uses, results))
    std::get<0>(pair).set(std::get<1>(pair));
}
// The pass: thread tokens through each block, then defer gpu.wait ops out of
// async.execute regions and make every !gpu.async.token result single-use.
void GpuAsyncRegionPass::runOnOperation() {
  if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
    return signalPassFailure();
  // ... walk with DeferWaitCallback and SingleTokenUseCallback ...
}

// Rewrites a function region so that GPU ops execute asynchronously.
std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
  return std::make_unique<GpuAsyncRegionPass>();
}
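For context, a minimal sketch of how the createGpuAsyncRegionPass() factory shown above might be scheduled from C++. The header locations, the helper name, and the assumption that the input module's dialects are already loaded are mine, not taken from this file; the pass can equally be run through mlir-opt, where it is registered as gpu-async-region.

// Sketch only: header paths and surrounding setup are assumptions.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h" // assumed location of createGpuAsyncRegionPass()
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static mlir::LogicalResult makeGpuOpsAsync(mlir::ModuleOp module,
                                           mlir::MLIRContext &context) {
  mlir::PassManager pm(&context);
  // The factory returns an OperationPass<func::FuncOp>, so nest it under the
  // module rather than adding it at module level.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createGpuAsyncRegionPass());
  return pm.run(module);
}

Nesting under func::FuncOp matches the pass's anchor operation; adding it directly to a module-level pipeline would fail the pipeline's op-type check.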