MLIR  21.0.0git
BufferUtils.cpp
Go to the documentation of this file.
1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for buffer optimization passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
19 #include "mlir/IR/Operation.h"
20 #include <optional>
21 
22 using namespace mlir;
23 using namespace mlir::bufferization;
24 
25 //===----------------------------------------------------------------------===//
26 // BufferPlacementAllocs
27 //===----------------------------------------------------------------------===//
28 
29 /// Get the start operation to place the given alloc value within the
30 /// specified placement block.
32  Block *placementBlock,
33  const Liveness &liveness) {
34  // We have to ensure that we place the alloc before its first use in this
35  // block.
36  const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
37  Operation *startOperation = livenessInfo.getStartOperation(allocValue);
38  // Check whether the start operation lies in the desired placement block.
39  // If not, we will use the terminator as this is the last operation in
40  // this block.
41  if (startOperation->getBlock() != placementBlock) {
42  Operation *opInPlacementBlock =
43  placementBlock->findAncestorOpInBlock(*startOperation);
44  startOperation = opInPlacementBlock ? opInPlacementBlock
45  : placementBlock->getTerminator();
46  }
47 
48  return startOperation;
49 }
50 
51 /// Initializes the internal list by discovering all supported allocation
52 /// nodes.
54 
55 /// Searches for and registers all supported allocation entries.
56 void BufferPlacementAllocs::build(Operation *op) {
57  op->walk([&](MemoryEffectOpInterface opInterface) {
58  // Try to find a single allocation result.
60  opInterface.getEffects(effects);
61 
63  llvm::copy_if(
64  effects, std::back_inserter(allocateResultEffects),
66  Value value = it.getValue();
67  return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
68  isa<OpResult>(value) &&
69  it.getResource() !=
71  });
72  // If there is one result only, we will be able to move the allocation and
73  // (possibly existing) deallocation ops.
74  if (allocateResultEffects.size() != 1)
75  return;
76  // Get allocation result.
77  Value allocValue = allocateResultEffects[0].getValue();
78  // Find the associated dealloc value and register the allocation entry.
79  std::optional<Operation *> dealloc = memref::findDealloc(allocValue);
80  // If the allocation has > 1 dealloc associated with it, skip handling it.
81  if (!dealloc)
82  return;
83  allocs.push_back(std::make_tuple(allocValue, *dealloc));
84  });
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // BufferPlacementTransformationBase
89 //===----------------------------------------------------------------------===//
90 
91 /// Constructs a new transformation base using the given root operation.
93  Operation *op)
94  : aliases(op), allocs(op), liveness(op) {}
95 
96 //===----------------------------------------------------------------------===//
97 // Memref global and symbol table utilities
98 //===----------------------------------------------------------------------===//
99 
100 FailureOr<memref::GlobalOp>
101 bufferization::getGlobalFor(arith::ConstantOp constantOp,
102  SymbolTableCollection &symbolTables,
103  uint64_t alignment, Attribute memorySpace) {
104  auto type = cast<RankedTensorType>(constantOp.getType());
105  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
106  if (!moduleOp)
107  return failure();
108 
109  // If we already have a global for this constant value, no need to do
110  // anything else.
111  for (Operation &op : moduleOp.getRegion().getOps()) {
112  auto globalOp = dyn_cast<memref::GlobalOp>(&op);
113  if (!globalOp)
114  continue;
115  if (!globalOp.getInitialValue().has_value())
116  continue;
117  uint64_t opAlignment = globalOp.getAlignment().value_or(0);
118  Attribute initialValue = globalOp.getInitialValue().value();
119  if (opAlignment == alignment && initialValue == constantOp.getValue())
120  return globalOp;
121  }
122 
123  // Create a builder without an insertion point. We will insert using the
124  // symbol table to guarantee unique names.
125  OpBuilder globalBuilder(moduleOp.getContext());
126  SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
127 
128  // Create a pretty name.
129  SmallString<64> buf;
130  llvm::raw_svector_ostream os(buf);
131  interleave(type.getShape(), os, "x");
132  os << "x" << type.getElementType();
133 
134  // Add an optional alignment to the global memref.
135  IntegerAttr memrefAlignment =
136  alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
137  : IntegerAttr();
138 
139  // Memref globals always have an identity layout.
140  auto memrefType =
141  cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type));
142  if (memorySpace)
143  memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
144  auto global = globalBuilder.create<memref::GlobalOp>(
145  constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
146  /*sym_visibility=*/globalBuilder.getStringAttr("private"),
147  /*type=*/memrefType,
148  /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
149  /*constant=*/true,
150  /*alignment=*/memrefAlignment);
151  symbolTable.insert(global);
152  // The symbol table inserts at the end of the module, but globals are a bit
153  // nicer if they are at the beginning.
154  global->moveBefore(&moduleOp.front());
155  return global;
156 }
157 
158 namespace mlir::bufferization {
160  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
162 
163  symbolTable.remove(op);
164 }
165 
167  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
169 
170  symbolTable.insert(op);
171 }
172 } // namespace mlir::bufferization
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:74
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
IntegerType getI64Type()
Definition: Builders.cpp:64
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
This class represents liveness information on block level.
Definition: Liveness.h:99
Operation * getStartOperation(Value value) const
Gets the start operation for the given value.
Definition: Liveness.cpp:364
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
Definition: Liveness.cpp:225
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:208
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
iterator_range< OpIterator > getOps()
Definition: Region.h:172
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static AutomaticAllocationScopeResource * get()
Returns a unique instance for the given effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static Operation * getStartOperation(Value allocValue, Block *placementBlock, const Liveness &liveness)
Get the start operation to place the given alloc value within the specified placement block.
Definition: BufferUtils.cpp:31
BufferPlacementAllocs(Operation *op)
Initializes the internal list by discovering all supported allocation nodes.
Definition: BufferUtils.cpp:53
BufferPlacementTransformationBase(Operation *op)
Constructs a new operation base using the given root operation.
Definition: BufferUtils.cpp:92
BufferizationState provides information about the state of the IR during the bufferization process.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
void insertSymbol(Operation *op, BufferizationState &state)
FailureOr< memref::GlobalOp > getGlobalFor(arith::ConstantOp constantOp, SymbolTableCollection &symbolTables, uint64_t alignment, Attribute memorySpace={})
void removeSymbol(Operation *op, BufferizationState &state)
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...