18#include "llvm/ADT/SetOperations.h"
22#define GEN_PASS_DEF_CHECKUSESPASS
23#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
37template <
typename FnTy>
39getReachableImpl(
Block *block, FnTy getNextNodes,
41 auto [it,
inserted] = cache.try_emplace(block);
43 return it->getSecond();
47 worklist.push_back(block);
48 while (!worklist.empty()) {
49 Block *current = worklist.pop_back_val();
50 for (
Block *predecessor : getNextNodes(current)) {
53 if (reachable.insert(predecessor).second)
54 worklist.push_back(predecessor);
81class TransformOpMemFreeAnalysis {
86 explicit TransformOpMemFreeAnalysis(
Operation *root) {
88 if (isa<transform::TransformOpInterface>(op)) {
89 collectFreedValues(op);
98 class PotentialDeleters {
101 static PotentialDeleters live() {
return PotentialDeleters({}); }
104 static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
105 return PotentialDeleters(deleters);
110 explicit operator bool()
const {
return !deleters.empty(); }
114 PotentialDeleters &
operator|=(
const PotentialDeleters &other) {
115 llvm::append_range(deleters, other.deleters);
120 ArrayRef<Operation *> getOps()
const {
return deleters; }
124 explicit PotentialDeleters(ArrayRef<Operation *> ops) {
125 llvm::append_range(deleters, ops);
129 SmallVector<Operation *> deleters;
136 PotentialDeleters isUseLive(OpOperand &operand) {
137 const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.
get()];
138 if (deleters.empty())
149 Operation *valueSource =
150 isa<OpResult>(operand.
get())
153 auto iface = cast<MemoryEffectOpInterface>(valueSource);
154 SmallVector<MemoryEffects::EffectInstance> instances;
155 iface.getEffectsOnResource(transform::TransformMappingResource::get(),
157 assert((isa<BlockArgument>(operand.
get()) ||
159 "expected the op defining the value to have an allocation effect "
165 SmallVector<Operation *> ancestors;
166 Operation *ancestor = operand.
getOwner();
168 ancestors.push_back(ancestor);
173 std::reverse(ancestors.begin(), ancestors.end());
178 for (Operation *ancestor : ancestors) {
183 bool isOutermost = ancestor == ancestors.front();
184 bool isFromBlockPartial = isOutermost && isa<OpResult>(operand.
get());
191 if (isFromBlockPartial) {
192 bool defUseSameBlock = ancestor->
getBlock() == defBlock;
196 if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
198 defUseSameBlock ? ancestor :
nullptr))
199 return potentialDeleters;
205 if (!isFromBlockPartial || ancestor->
getBlock() != defBlock) {
206 if (PotentialDeleters potentialDeleters =
207 isFreedInBlockBefore(ancestor, operand.
get()))
208 return potentialDeleters;
222 if (PotentialDeleters potentialDeleters =
223 isMaybeFreedOnPaths(from, ancestorBlock, operand.
get(),
224 !isFromBlockPartial))
225 return potentialDeleters;
232 static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
233 return PotentialDeleters::maybeFreed(deleters);
235 static PotentialDeleters live() {
return PotentialDeleters::live(); }
241 isFreedBetween(Value value, Operation *first, Operation *last,
242 llvm::function_ref<Operation *(Operation *)> getNext)
const {
243 auto it = freedBy.find(value);
244 if (it == freedBy.end())
246 const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
247 for (Operation *op = getNext(first); op != last; op = getNext(op)) {
248 if (deleters.contains(op))
249 return maybeFreed(op);
258 PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
259 Operation *before =
nullptr)
const {
260 return isFreedBetween(value, root, before,
261 [](Operation *op) {
return op->getNextNode(); });
266 PotentialDeleters isFreedInBlockBefore(Operation *root, Value value)
const {
267 return isFreedBetween(value, root,
nullptr,
268 [](Operation *op) {
return op->getPrevNode(); });
279 PotentialDeleters isMaybeFreedOnPaths(
Block *from,
Block *to, Value value,
280 bool alwaysIncludeFrom) {
284 const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
285 if (!sources.contains(from))
288 llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
289 llvm::set_intersect(reachable, sources);
293 if (alwaysIncludeFrom)
294 reachable.insert(from);
298 PotentialDeleters potentialDeleters = live();
299 for (
Block *block : reachable) {
300 for (Operation &op : *block) {
301 if (freedBy[value].count(&op))
302 potentialDeleters |= maybeFreed(&op);
305 return potentialDeleters;
311 const llvm::SmallPtrSet<Block *, 4> &getReachable(
Block *block) {
312 return getReachableImpl(
313 block, [](
Block *
b) {
return b->getSuccessors(); }, reachableCache);
318 const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(
Block *block) {
319 return getReachableImpl(
320 block, [](
Block *
b) {
return b->getPredecessors(); },
325 template <
typename EffectTy>
326 static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
328 return llvm::any_of(instances,
330 return instance.
getValue() == value &&
337 void collectFreedValues(Operation *root) {
338 SmallVector<MemoryEffects::EffectInstance> instances;
339 root->
walk([&](Operation *child) {
340 if (isa<transform::PatternDescriptorOpInterface>(child))
344 auto iface = cast<MemoryEffectOpInterface>(child);
346 iface.getEffectsOnResource(transform::TransformMappingResource::get(),
358 Operation *parent = child;
360 freedBy[operand].insert(parent);
384 void runOnOperation()
override {
385 auto &
analysis = getAnalysis<TransformOpMemFreeAnalysis>();
387 getOperation()->walk([&](Operation *child) {
389 TransformOpMemFreeAnalysis::PotentialDeleters deleters =
396 <<
" may be used after free";
398 for (Operation *d : deleters.getOps()) {
399 diag.attachNote(d->getLoc()) <<
"freed here";
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::string diag(const llvm::Value &value)
template bool mlir::hasEffect< MemoryEffects::Allocate >(Operation *)
template bool mlir::hasEffect< MemoryEffects::Free >(Operation *)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Region * getParentRegion()
Returns the region to which the instruction belongs.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
ChangeResult & operator|=(ChangeResult &lhs, ChangeResult rhs)