27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
31 #define GEN_PASS_DEF_GPUELIMINATEBARRIERS
32 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
38 #define DEBUG_TYPE "gpu-erase-barriers"
39 #define DEBUG_TYPE_ALIAS "gpu-erase-barries-alias"
41 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
42 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
51 return isa<memref::AssumeAlignmentOp>(op);
57 if (op->
hasAttr(
"__parallel_region_boundary_for_test"))
60 return isa<GPUFuncOp, LaunchOp>(op);
71 return isa<scf::IfOp, memref::AllocaScopeOp>(op);
78 return isa_and_nonnull<memref::AllocOp, memref::AllocaOp>(op);
85 effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
86 effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
87 effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Allocate>());
88 effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
98 bool ignoreBarriers =
true) {
101 if (ignoreBarriers && isa<BarrierOp>(op))
112 if (
auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
114 iface.getEffects(localEffects);
115 llvm::append_range(effects, localEffects);
120 for (
auto &block : region) {
121 for (
auto &innerOp : block)
140 bool stopAtBarrier) {
144 for (
Operation *it = op->getPrevNode(); it !=
nullptr;
145 it = it->getPrevNode()) {
146 if (isa<BarrierOp>(it)) {
167 bool stopAtBarrier) {
173 if (region && !llvm::hasSingleElement(region->
getBlocks())) {
211 bool conservative =
false;
223 return !conservative;
231 bool stopAtBarrier) {
235 for (
Operation *it = op->getNextNode(); it !=
nullptr;
236 it = it->getNextNode()) {
237 if (isa<BarrierOp>(it)) {
257 bool stopAtBarrier) {
263 if (region && !llvm::hasSingleElement(region->
getBlocks())) {
305 bool conservative =
false;
317 return !conservative;
327 bool shouldContinue =
329 .Case<memref::CastOp, memref::SubViewOp, memref::ViewOp>(
334 .Case<memref::TransposeOp>([&](
auto op) {
338 .Case<memref::CollapseShapeOp, memref::ExpandShapeOp>([&](
auto op) {
342 .Default([](
Operation *) {
return false; });
351 auto arg = dyn_cast<BlockArgument>(v);
352 return arg && isa<FunctionOpInterface>(arg.getOwner()->getParentOp());
361 [](ViewLikeOpInterface viewLike) {
return viewLike.getViewSource(); })
362 .Case([](CastOpInterface castLike) {
return castLike->getOperand(0); })
364 .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(
365 [](
auto op) {
return op.getSrc(); })
376 .Case<memref::StoreOp, vector::TransferWriteOp>(
377 [&](
auto op) {
return op.getValue() == v; })
378 .Case<vector::StoreOp, vector::MaskedStoreOp>(
379 [&](
auto op) {
return op.getValueToStore() == v; })
381 .Case([](memref::DeallocOp) {
return false; })
383 .Default([](
Operation *) {
return std::nullopt; });
392 while (!todo.empty()) {
393 Value v = todo.pop_back_val();
396 auto iface = dyn_cast<MemoryEffectOpInterface>(user);
399 iface.getEffects(effects);
400 if (llvm::all_of(effects,
402 return isa<MemoryEffects::Read>(effect.
getEffect());
416 if (!knownCaptureStatus || *knownCaptureStatus)
453 if (first == second) {
459 if (
auto globFirst = first.
getDefiningOp<memref::GetGlobalOp>()) {
460 if (
auto globSecond = second.
getDefiningOp<memref::GetGlobalOp>()) {
461 return globFirst.getNameAttr() == globSecond.getNameAttr();
466 auto isNoaliasFuncArgument = [](
Value value) {
467 auto bbArg = dyn_cast<BlockArgument>(value);
470 auto iface = dyn_cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
474 return iface.getArgAttr(bbArg.getArgNumber(),
"llvm.noalias") !=
nullptr;
476 if (isNoaliasFuncArgument(first) && isNoaliasFuncArgument(second))
481 bool isGlobal[] = {first.
getDefiningOp<memref::GetGlobalOp>() !=
nullptr,
487 if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1]))
493 if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0]))
551 if (isa<MemoryEffects::Read>(before.getEffect()) &&
552 isa<MemoryEffects::Read>(after.getEffect())) {
560 if (isa<MemoryEffects::Allocate>(before.getEffect()) ||
561 isa<MemoryEffects::Allocate>(after.getEffect())) {
573 if (isa<MemoryEffects::Free>(before.getEffect()))
578 DBGS() <<
"found a conflict between (before): " << before.getValue()
579 <<
" read:" << isa<MemoryEffects::Read>(before.getEffect())
580 <<
" write:" << isa<MemoryEffects::Write>(before.getEffect())
582 << isa<MemoryEffects::Allocate>(before.getEffect()) <<
" free:"
583 << isa<MemoryEffects::Free>(before.getEffect()) <<
"\n");
585 DBGS() <<
"and (after): " << after.getValue()
586 <<
" read:" << isa<MemoryEffects::Read>(after.getEffect())
587 <<
" write:" << isa<MemoryEffects::Write>(after.getEffect())
588 <<
" alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
589 <<
" free:" << isa<MemoryEffects::Free>(after.getEffect())
603 LogicalResult matchAndRewrite(BarrierOp barrier,
605 LLVM_DEBUG(
DBGS() <<
"checking the necessity of: " << barrier <<
" "
606 << barrier.getLoc() <<
"\n");
615 LLVM_DEBUG(
DBGS() <<
"the surrounding barriers are sufficient, removing "
621 LLVM_DEBUG(
DBGS() <<
"barrier is necessary: " << barrier <<
" "
622 << barrier.getLoc() <<
"\n");
627 class GpuEliminateBarriersPass
628 :
public impl::GpuEliminateBarriersBase<GpuEliminateBarriersPass> {
629 void runOnOperation()
override {
630 auto funcOp = getOperation();
634 return signalPassFailure();
static bool isSequentialLoopLike(Operation *op)
Returns true if the op behaves like a sequential loop, e.g., the control flow "wraps around" from the...
static bool isFunctionArgument(Value v)
Returns true if the value is defined as a function argument.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static Value propagatesCapture(Operation *op)
Returns the operand that the operation "propagates" through it for capture purposes.
static bool hasSingleExecutionBody(Operation *op)
Returns true if the regions of the op are guaranteed to be executed at most once.
static bool producesDistinctBase(Operation *op)
Returns true if the operation is known to produce a pointer-like object distinct from any other objec...
static bool mayAlias(Value first, Value second)
Returns true if two values may be referencing aliasing memory.
static bool isKnownNoEffectsOpWithoutInterface(Operation *op)
Returns true if the op is known not to have any side effects but does not implement the MemoryEffectOpInterface in the suitable way.
static bool getEffectsBeforeInBlock(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)
Get all effects before the given operation caused by other operations in the same block.
static bool isParallelRegionBoundary(Operation *op)
Returns true if the op defines the parallel region that is subject to barrier synchronization.
static bool getEffectsAfter(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)
Collects memory effects from operations that may be executed after op in a trivial structured control...
static std::optional< bool > getKnownCapturingStatus(Operation *op, Value v)
Returns true if the given operation is known to capture the given value, false if it is known not to ...
static bool collectEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool ignoreBarriers=true)
Collect the memory effects of the given op in 'effects'.
static bool haveConflictingEffects(ArrayRef< MemoryEffects::EffectInstance > beforeEffects, ArrayRef< MemoryEffects::EffectInstance > afterEffects)
Returns true if any of the "before" effect instances has a conflict with any "after" instance for the...
static bool getEffectsAfterInBlock(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)
Get all effects after the given operation caused by other operations in the same block.
static void addAllValuelessEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with all memory effects without associating them to a specific value.
static bool maybeCaptured(Value v)
Returns true if the value may be captured by any of its users, i.e., if the user may be storing this ...
static bool getEffectsBefore(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)
Collects memory effects from operations that may be executed before op in a trivial structured contro...
static MLIRContext * getContext(OpFoldResult val)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
This trait indicates that the memory effects of an operation include the effects of operations neste...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
TypeID getResourceID() const
Return the unique identifier for the base resource class.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns)
Erase barriers that do not enforce conflicting memory side effects.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...