28 namespace bufferization {
29 #define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
30 #define GEN_PASS_DEF_ONESHOTBUFFERIZE
31 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
35 #define DEBUG_TYPE "bufferize"
/// Translate the textual layout-map option `s` into a LayoutMapOption value.
///
/// Recognized spellings and their results:
///   "fully-dynamic-layout-map" -> LayoutMapOption::FullyDynamicLayoutMap
///   "identity-layout-map"      -> LayoutMapOption::IdentityLayoutMap
///   "infer-layout-map"         -> LayoutMapOption::InferLayoutMap
///
/// The comparison is an exact string match. Any other string falls through to
/// llvm_unreachable, i.e. there is no recoverable error path for an
/// unrecognized spelling.
42 static LayoutMapOption parseLayoutMapOption(
const std::string &s) {
43 if (s ==
"fully-dynamic-layout-map")
44 return LayoutMapOption::FullyDynamicLayoutMap;
45 if (s ==
"identity-layout-map")
46 return LayoutMapOption::IdentityLayoutMap;
47 if (s ==
"infer-layout-map")
48 return LayoutMapOption::InferLayoutMap;
// None of the known spellings matched.
49 llvm_unreachable(
"invalid layout map option");
53 parseHeuristicOption(
const std::string &s) {
58 if (s ==
"bottom-up-from-terminators")
63 llvm_unreachable(
"invalid analysisheuristic option");
66 struct OneShotBufferizePass
67 :
public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
68 OneShotBufferizePass() =
default;
75 .
insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
78 void runOnOperation()
override {
90 parseLayoutMapOption(functionBoundaryTypeConversion));
92 if (mustInferMemorySpace && useEncodingForMemorySpace) {
94 <<
"only one of 'must-infer-memory-space' and "
95 "'use-encoding-for-memory-space' are allowed in "
97 return signalPassFailure();
100 if (mustInferMemorySpace) {
102 [](
TensorType t) -> std::optional<Attribute> {
107 if (useEncodingForMemorySpace) {
109 [](
TensorType t) -> std::optional<Attribute> {
110 if (
auto rtt = dyn_cast<RankedTensorType>(t))
111 return rtt.getEncoding();
124 LayoutMapOption unknownTypeConversionOption =
125 parseLayoutMapOption(unknownTypeConversion);
126 if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
128 "Invalid option: 'infer-layout-map' is not a valid value for "
129 "'unknown-type-conversion'");
130 return signalPassFailure();
134 auto tensorType = cast<TensorType>(value.
getType());
135 if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
137 tensorType, memorySpace);
138 assert(unknownTypeConversionOption ==
139 LayoutMapOption::FullyDynamicLayoutMap &&
140 "invalid layout map option");
148 if (this->dialectFilter.hasValue())
149 return llvm::is_contained(this->dialectFilter,
150 op->getDialect()->getNamespace());
165 "Invalid option: 'copy-before-write' cannot be used with "
166 "'test-analysis-only'");
167 return signalPassFailure();
173 "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
174 return signalPassFailure();
180 "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
181 return signalPassFailure();
185 ModuleOp moduleOp = getOperation();
194 "Invalid option: 'no-analysis-func-filter' requires "
195 "'bufferize-function-boundaries'");
196 return signalPassFailure();
211 std::optional<OneShotBufferizationOptions>
options;
216 return std::make_unique<OneShotBufferizePass>();
221 return std::make_unique<OneShotBufferizePass>(
options);
237 :
IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
238 worklist(worklist), analysisState(
options), statistics(statistics) {
// Override: records `op` in `erasedOps` and removes it from `toMemrefOps`.
// NOTE(review): presumably this lets later processing skip erased ops via the
// `erasedOps.contains(...)` checks seen elsewhere in this file - confirm.
243 void notifyOperationErased(
Operation *op)
override {
244 erasedOps.insert(op);
// Drop the op from the to_memref tracking set as well.
246 toMemrefOps.erase(op);
249 void notifyOperationInserted(
Operation *op, InsertPoint previous)
override {
251 if (previous.isSet())
258 if (
auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
259 statistics->numBufferAlloc +=
static_cast<int64_t
>(
264 if (isa<ToMemrefOp>(op)) {
265 toMemrefOps.insert(op);
270 if (isa<ToTensorOp>(op))
278 auto const &
options = analysisState.getOptions();
283 worklist.push_back(op);
316 op->
walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
328 worklist.push_back(op);
335 BufferizationRewriter rewriter(op->
getContext(), erasedOps, toMemrefOps,
336 worklist,
options, statistics);
337 for (
unsigned i = 0; i < worklist.size(); ++i) {
340 if (erasedOps.contains(nextOp))
343 auto bufferizableOp =
options.dynCastBufferizableOp(nextOp);
350 if (!bufferizableOp.supportsUnstructuredControlFlow())
352 if (r.getBlocks().size() > 1)
354 "op or BufferizableOpInterface implementation does not support "
355 "unstructured control flow, but at least one region has multiple "
359 LLVM_DEBUG(llvm::dbgs()
360 <<
"//===-------------------------------------------===//\n"
361 <<
"IR after bufferizing: " << nextOp->
getName() <<
"\n");
362 rewriter.setInsertionPoint(nextOp);
363 if (failed(bufferizableOp.bufferize(rewriter,
options))) {
364 LLVM_DEBUG(llvm::dbgs()
365 <<
"failed to bufferize\n"
366 <<
"//===-------------------------------------------===//\n");
367 return nextOp->
emitError(
"failed to bufferize op");
369 LLVM_DEBUG(llvm::dbgs()
371 <<
"\n//===-------------------------------------------===//\n");
375 if (erasedOps.contains(op))
380 rewriter.setInsertionPoint(op);
382 rewriter, cast<ToMemrefOp>(op),
options);
387 if (toTensorOp->getUses().empty()) {
388 rewriter.eraseOp(toTensorOp);
401 if (erasedOps.contains(op))
414 if (isa<ToTensorOp, ToMemrefOp>(op))
416 return op->
emitError(
"op was not bufferized");
433 auto tensorType = dyn_cast<TensorType>(bbArg.getType());
435 newTypes.push_back(bbArg.getType());
439 FailureOr<BaseMemRefType> memrefType =
441 if (failed(memrefType))
443 newTypes.push_back(*memrefType);
447 for (
auto [bbArg, type] : llvm::zip(block->
getArguments(), newTypes)) {
448 if (bbArg.getType() == type)
454 bbArgUses.push_back(&use);
461 if (!bbArgUses.empty()) {
463 rewriter.
create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
465 use->set(toTensorOp);
471 auto branchOp = dyn_cast<BranchOpInterface>(op);
473 return op->
emitOpError(
"cannot bufferize ops with block references that "
474 "do not implement BranchOpInterface");
477 assert(it != op->
getSuccessors().end() &&
"could find successor");
478 int64_t successorIdx = std::distance(op->
getSuccessors().begin(), it);
482 for (
auto [operand, type] :
484 if (operand.getType() == type) {
486 newOperands.push_back(operand);
489 FailureOr<BaseMemRefType> operandBufferType =
491 if (failed(operandBufferType))
494 Value bufferizedOperand = rewriter.
create<bufferization::ToMemrefOp>(
495 operand.getLoc(), *operandBufferType, operand);
498 if (type != *operandBufferType)
499 bufferizedOperand = rewriter.
create<memref::CastOp>(
500 operand.getLoc(), type, bufferizedOperand);
501 newOperands.push_back(bufferizedOperand);
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
user_range getUsers() const
Returns a range of all users.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
MLIRContext is the top-level object for a collection of MLIR operations.
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
SuccessorRange getSuccessors()
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
AnalysisState provides a variety of helper functions for dealing with tensor values.
void allowOperation()
Allow the given ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
llvm::LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
std::unique_ptr< Pass > createOneShotBufferizePass()
Create a pass that bufferizes all ops that implement BufferizableOpInterface with One-Shot Bufferize.
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref, const BufferizationOptions &options)
Try to fold to_memref(to_tensor(x)).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
The following effect indicates that the operation allocates from some resource.
Options for BufferizableOpInterface-based bufferization.
bool copyBeforeWrite
If set to true, the analysis is skipped.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption)
This function controls buffer types on function signatures.
bool allowUnknownOps
Specifies whether non-bufferizable ops are allowed in the input.
unsigned int bufferAlignment
Buffer alignment for new memory allocations.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the ...
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
bool checkParallelRegions
UnknownTypeConverterFn unknownTypeConverterFn
Type converter from tensors to memrefs.
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not.
DefaultMemorySpaceFn defaultMemorySpaceFn
Bufferization statistics for debugging.
int64_t numTensorOutOfPlace
Options for analysis-enabled bufferization.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
@ BottomUpFromTerminators
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.