24 #include "llvm/ADT/TypeSwitch.h"
27 #define GEN_PASS_DEF_GPUASYNCREGIONPASS
28 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
34 class GpuAsyncRegionPass
35 :
public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
36 struct ThreadTokenCallback;
37 struct DeferWaitCallback;
38 struct SingleTokenUseCallback;
39 void runOnOperation()
override;
54 for (
Operation &op : make_early_inc_range(*block)) {
71 if (isa<gpu::LaunchOp>(op))
72 return op->
emitOpError(
"replace with gpu.launch_func first");
73 if (
auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
75 waitOp.addAsyncDependency(currentToken);
76 currentToken = waitOp.getAsyncToken();
79 builder.setInsertionPoint(op);
80 if (
auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
81 return rewriteAsyncOp(asyncOp);
86 currentToken = createWaitOp(op->
getLoc(),
Type(), {currentToken});
91 LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
92 auto *op = asyncOp.getOperation();
98 currentToken = createWaitOp(op->
getLoc(), tokenType, {});
99 asyncOp.addAsyncDependency(currentToken);
102 currentToken = asyncOp.getAsyncToken();
110 resultTypes.push_back(tokenType);
118 for (
auto pair : llvm::zip_first(op->
getRegions(), newOp->getRegions()))
119 std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
122 auto results = newOp->getResults();
123 currentToken = results.back();
124 builder.insert(newOp);
132 return gpu::WaitOp::create(builder, loc, resultType, operands)
142 Value currentToken = {};
149 Operation *yieldOp = executeOp.getBody()->getTerminator();
154 resultTypes.reserve(executeOp.getNumResults() + results.size());
155 transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
158 if (auto valueType = dyn_cast<async::ValueType>(type))
159 return valueType.getValueType();
160 assert(isa<async::TokenType>(type) &&
"expected token type");
163 transform(results, std::back_inserter(resultTypes),
168 auto newOp = async::ExecuteOp::create(
169 builder, executeOp.getLoc(),
171 executeOp.getDependencies(), executeOp.getBodyOperands());
173 newOp.getRegion().getBlocks().
clear();
174 executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
177 executeOp.getOperation()->replaceAllUsesWith(
178 newOp.getResults().drop_back(results.size()));
191 if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
194 for (
auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
195 if (
auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
196 if (!waitOp.getAsyncToken())
197 worklist.push_back(waitOp);
207 for (
size_t i = 0; i < worklist.size(); ++i) {
208 auto waitOp = worklist[i];
209 auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
217 auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
219 executeOp.getToken().user_end());
221 addAsyncDependencyAfter(asyncTokens, user);
231 static bool areAllUsersExecuteOrAwait(
Value token) {
234 llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
244 tokens.reserve(asyncTokens.size());
246 .Case<async::AwaitOp>([&](
auto awaitOp) {
248 builder.setInsertionPointAfter(op);
249 for (
auto asyncToken : asyncTokens)
251 async::AwaitOp::create(builder, loc, asyncToken).getResult());
253 it = builder.getInsertionPoint();
255 .Case<async::ExecuteOp>([&](
auto executeOp) {
258 it = executeOp.getBody()->begin();
259 executeOp.getBodyOperandsMutable().append(asyncTokens);
264 copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
265 std::back_inserter(tokens));
275 if (
auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
276 for (
auto token : tokens)
277 asyncOp.addAsyncDependency(token);
282 builder.setInsertionPoint(it->getBlock(), it);
283 auto waitOp = gpu::WaitOp::create(builder, loc,
Type{}, tokens);
287 auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
288 if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
290 worklist.push_back(waitOp);
301 auto multiUseResults = llvm::make_filter_range(
302 executeOp.getBodyResults(), [](
OpResult result) {
303 if (result.use_empty() || result.hasOneUse())
305 auto valueType = dyn_cast<async::ValueType>(result.getType());
307 isa<gpu::AsyncTokenType>(valueType.getValueType());
309 if (multiUseResults.empty())
314 transform(multiUseResults, std::back_inserter(indices),
319 for (
auto index : indices) {
320 assert(!executeOp.getBodyResults()[index].getUses().empty());
322 auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
323 auto count = std::distance(uses.begin(), uses.end());
324 auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
328 uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
329 auto results = executeOp.getBodyResults().take_back(count);
330 for (
auto pair : llvm::zip(uses, results))
331 std::get<0>(pair).set(std::get<1>(pair));
339 void GpuAsyncRegionPass::runOnOperation() {
340 if (getOperation()->
walk(ThreadTokenCallback(
getContext())).wasInterrupted())
341 return signalPassFailure();
344 getOperation().getRegion().walk(DeferWaitCallback());
346 getOperation().getRegion().walk(SingleTokenUseCallback());
static bool isTerminator(Operation *op)
async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, ValueRange results)
Erases executeOp and returns a clone with additional results.
static bool hasSideEffects(Operation *op)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
OpListType::iterator iterator
This is a utility class for mapping one set of IR entities to another.
void clear()
Clears all mappings held by the mapper.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
SuccessorRange getSuccessors()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
user_range getUsers() const
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
void operator()(async::ExecuteOp executeOp)
void operator()(async::ExecuteOp executeOp)
WalkResult operator()(Block *block)
ThreadTokenCallback(MLIRContext &context)