28#define GEN_PASS_DEF_BUFFERHOISTINGPASS
29#define GEN_PASS_DEF_BUFFERLOOPHOISTINGPASS
30#define GEN_PASS_DEF_PROMOTEBUFFERSTOSTACKPASS
31#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
41 return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
51 if (isa<LoopLikeOpInterface>(op))
56 auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
60 return regionInterface.hasLoop();
72 auto allocOp = dyn_cast<AllocationOpInterface>(op);
80 auto allocOp = dyn_cast<AllocationOpInterface>(op);
89 unsigned maxRankOfAllocatedMemRef) {
90 auto type = dyn_cast<ShapedType>(alloc.
getType());
93 if (!type.hasStaticShape()) {
99 if (type.getRank() <= maxRankOfAllocatedMemRef) {
102 return operand.getDefiningOp<memref::RankOp>();
108 Type elemType = type.getElementType();
110 !isa<ComplexType, IndexType, VectorType>(elemType) &&
111 !isa<DataLayoutTypeInterface>(elemType))
118 std::optional<int64_t> numElements = type.tryGetNumElements();
123 *numElements >
static_cast<int64_t>(maximumSizeInBytes * 8ULL / bitwidth))
125 return *numElements * bitwidth <= maximumSizeInBytes * 8;
132 for (
Value alias : aliases) {
133 for (
auto *use : alias.getUsers()) {
137 if (isa<RegionBranchTerminatorOpInterface>(use) &&
138 use->getParentRegion() == parentRegion)
174struct BufferAllocationHoistingStateBase {
176 DominanceInfo *dominators;
182 Block *placementBlock;
185 BufferAllocationHoistingStateBase(DominanceInfo *dominators, Value allocValue,
186 Block *placementBlock)
187 : dominators(dominators), allocValue(allocValue),
188 placementBlock(placementBlock) {}
192template <
typename StateT>
195 BufferAllocationHoisting(Operation *op)
196 : BufferPlacementTransformationBase(op), dominators(op),
197 postDominators(op), scopeOp(op) {}
201 SmallVector<Value> allocsAndAllocas;
203 allocsAndAllocas.push_back(std::get<0>(entry));
204 scopeOp->walk([&](memref::AllocaOp op) {
205 allocsAndAllocas.push_back(op.getMemref());
208 for (
auto allocValue : allocsAndAllocas) {
209 if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))
211 Operation *definingOp = allocValue.getDefiningOp();
212 assert(definingOp &&
"No defining op");
216 if (!dominators.isReachableFromEntry(allocValue.getParentBlock()))
219 auto resultAliases = aliases.resolve(allocValue);
221 Block *dominatorBlock =
224 StateT state(&dominators, allocValue, allocValue.getParentBlock());
227 Block *dependencyBlock =
nullptr;
231 for (Value depValue : operands) {
232 Block *depBlock = depValue.getParentBlock();
233 if (!dependencyBlock || dominators.dominates(dependencyBlock, depBlock))
234 dependencyBlock = depBlock;
240 Block *placementBlock = findPlacementBlock(
241 state, state.computeUpperBound(dominatorBlock, dependencyBlock));
243 allocValue, placementBlock, liveness);
246 Operation *allocOperation = allocValue.getDefiningOp();
255 Block *findPlacementBlock(StateT &state,
Block *upperBound) {
256 Block *currentBlock = state.placementBlock;
267 (parentBlock = parentOp->
getBlock()) &&
269 dominators.properlyDominates(upperBound, currentBlock))) {
276 idom = dominators.getNode(currentBlock)->getIDom();
278 if (idom && dominators.properlyDominates(parentBlock, idom->getBlock())) {
281 currentBlock = idom->getBlock();
282 state.recordMoveToDominator(currentBlock);
290 !state.isLegalPlacement(parentOp))
294 currentBlock = parentBlock;
295 state.recordMoveToParent(currentBlock);
299 return state.placementBlock;
304 DominanceInfo dominators;
308 PostDominanceInfo postDominators;
311 llvm::DenseMap<Value, Block *> placementBlocks;
321struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
322 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
325 Block *computeUpperBound(
Block *dominatorBlock,
Block *dependencyBlock) {
328 if (!dependencyBlock)
329 return dominatorBlock;
333 return dominators->properlyDominates(dominatorBlock, dependencyBlock)
339 bool isLegalPlacement(Operation *op) {
return !
isLoop(op); }
342 static bool shouldHoistOpType(Operation *op) {
347 void recordMoveToDominator(
Block *block) { placementBlock = block; }
350 void recordMoveToParent(
Block *block) { recordMoveToDominator(block); }
355struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
356 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
359 Block *aliasDominatorBlock =
nullptr;
362 Block *computeUpperBound(
Block *dominatorBlock,
Block *dependencyBlock) {
363 aliasDominatorBlock = dominatorBlock;
366 return dependencyBlock ? dependencyBlock :
nullptr;
374 bool isLegalPlacement(Operation *op) {
376 !dominators->dominates(aliasDominatorBlock, op->
getBlock());
380 static bool shouldHoistOpType(Operation *op) {
386 void recordMoveToDominator(
Block *block) {}
389 void recordMoveToParent(
Block *block) { placementBlock = block; }
399 BufferPlacementPromotion(Operation *op)
400 : BufferPlacementTransformationBase(op) {}
405 Value alloc = std::get<0>(entry);
406 Operation *dealloc = std::get<1>(entry);
411 if (!isSmallAlloc(alloc) || dealloc ||
419 OpBuilder builder(startOperation);
421 if (
auto allocInterface = dyn_cast<AllocationOpInterface>(allocOp)) {
422 std::optional<Operation *> alloca =
423 allocInterface.buildPromotedAlloc(builder, alloc);
440struct BufferHoistingPass
441 :
public bufferization::impl::BufferHoistingPassBase<BufferHoistingPass> {
443 void runOnOperation()
override {
445 BufferAllocationHoisting<BufferAllocationHoistingState> optimizer(
452struct BufferLoopHoistingPass
453 :
public bufferization::impl::BufferLoopHoistingPassBase<
454 BufferLoopHoistingPass> {
456 void runOnOperation()
override {
464class PromoteBuffersToStackPass
465 :
public bufferization::impl::PromoteBuffersToStackPassBase<
466 PromoteBuffersToStackPass> {
470 explicit PromoteBuffersToStackPass(std::function<
bool(Value)> isSmallAlloc)
471 : isSmallAlloc(std::move(isSmallAlloc)) {}
473 LogicalResult
initialize(MLIRContext *context)
override {
474 if (isSmallAlloc ==
nullptr) {
475 isSmallAlloc = [=](Value alloc) {
477 maxRankOfAllocatedMemRef);
483 void runOnOperation()
override {
485 BufferPlacementPromotion optimizer(getOperation());
486 optimizer.promote(isSmallAlloc);
490 std::function<bool(Value)> isSmallAlloc;
496 BufferAllocationHoisting<BufferAllocationLoopHoistingState> optimizer(op);
501 std::function<
bool(
Value)> isSmallAlloc) {
502 return std::make_unique<PromoteBuffersToStackPass>(std::move(isSmallAlloc));
static bool leavesAllocationScope(Region *parentRegion, const BufferViewFlowAnalysis::ValueSetT &aliases)
Checks whether the given aliases leave the allocation scope.
static bool isKnownControlFlowInterface(Operation *op)
Returns true if the given operation implements a known high-level region-based control-flow interface.
static bool hasAllocationScope(Value alloc, const BufferViewFlowAnalysis &aliasAnalysis)
Checks whether an automated allocation scope exists for the given alloc value.
static bool isSequentialLoop(Operation *op)
Return whether the given operation is a loop with sequential execution semantics.
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool allowAllocDominateBlockHoisting(Operation *op)
Returns true if the given operation implements the AllocationOpInterface and it supports the dominate...
static bool allowAllocLoopHoisting(Operation *op)
Returns true if the given operation implements the AllocationOpInterface and it supports the loop hoi...
static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef)
Check if the size of the allocation is less than the given size.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
bool isEntryBlock()
Return true if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A straight-forward alias analysis which ensures that all dependencies of all values will be determine...
SmallPtrSet< Value, 16 > ValueSetT
ValueSetT resolve(Value value) const
Find all immediate and indirect views upon this value.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
operand_range getOperands()
Returns an iterator range over the underlying Values.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * getParentOp()
Return the parent operation this region is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
static Operation * getStartOperation(Value allocValue, Block *placementBlock, const Liveness &liveness)
Get the start operation to place the given alloc value within the specified placement block.
std::tuple< Value, Operation * > AllocEntry
Represents a tuple of allocValue and deallocOperation.
void hoistBuffersFromLoops(Operation *op)
Within the given operation, hoist buffers from loops where possible.
std::unique_ptr< Pass > createPromoteBuffersToStackPass(std::function< bool(Value)> isSmallAlloc)
Creates a pass that promotes heap-based allocations to stack-based ones.
Block * findCommonDominator(Value value, const BufferViewFlowAnalysis::ValueSetT &values, const DominatorT &doms)
Finds a common dominator for the given value while taking the positions of the values in the value se...
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Include the generated interface declarations.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
llvm::function_ref< Fn > function_ref