28 namespace bufferization {
29 #define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
30 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
34 #define DEBUG_TYPE "bufferize"
42 parseHeuristicOption(
const std::string &s) {
47 if (s ==
"bottom-up-from-terminators")
52 llvm_unreachable(
"invalid analysisheuristic option");
55 struct OneShotBufferizePass
56 :
public bufferization::impl::OneShotBufferizePassBase<
57 OneShotBufferizePass> {
60 void runOnOperation()
override {
73 if (mustInferMemorySpace && useEncodingForMemorySpace) {
75 <<
"only one of 'must-infer-memory-space' and "
76 "'use-encoding-for-memory-space' are allowed in "
78 return signalPassFailure();
81 if (mustInferMemorySpace) {
88 if (useEncodingForMemorySpace) {
91 if (
auto rtt = dyn_cast<RankedTensorType>(t))
92 return rtt.getEncoding();
105 LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
106 if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
108 "Invalid option: 'infer-layout-map' is not a valid value for "
109 "'unknown-type-conversion'");
110 return signalPassFailure();
114 auto tensorType = cast<TensorType>(value.
getType());
115 if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
117 tensorType, memorySpace);
118 assert(unknownTypeConversionOption ==
119 LayoutMapOption::FullyDynamicLayoutMap &&
120 "invalid layout map option");
128 if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
129 return llvm::is_contained(this->dialectFilter,
130 op->getDialect()->getNamespace());
145 "Invalid option: 'copy-before-write' cannot be used with "
146 "'test-analysis-only'");
147 return signalPassFailure();
153 "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
154 return signalPassFailure();
160 "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
161 return signalPassFailure();
166 ModuleOp moduleOp = getOperation();
176 "Invalid option: 'no-analysis-func-filter' requires "
177 "'bufferize-function-boundaries'");
178 return signalPassFailure();
193 std::optional<OneShotBufferizationOptions>
options;
210 :
IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
211 worklist(worklist), analysisState(
options), statistics(statistics) {
216 void notifyOperationErased(
Operation *op)
override {
217 erasedOps.insert(op);
219 toBufferOps.erase(op);
222 void notifyOperationInserted(
Operation *op, InsertPoint previous)
override {
224 if (previous.isSet())
231 if (
auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
232 statistics->numBufferAlloc +=
static_cast<int64_t
>(
237 if (isa<ToBufferOp>(op)) {
238 toBufferOps.insert(op);
243 if (isa<ToTensorOp>(op))
251 auto const &
options = analysisState.getOptions();
256 worklist.push_back(op);
290 op->
walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });
302 worklist.push_back(op);
309 BufferizationRewriter rewriter(op->
getContext(), erasedOps, toBufferOps,
310 worklist,
options, statistics);
311 for (
unsigned i = 0; i < worklist.size(); ++i) {
314 if (erasedOps.contains(nextOp))
317 auto bufferizableOp =
options.dynCastBufferizableOp(nextOp);
324 if (!bufferizableOp.supportsUnstructuredControlFlow())
326 if (r.getBlocks().size() > 1)
328 "op or BufferizableOpInterface implementation does not support "
329 "unstructured control flow, but at least one region has multiple "
333 LLVM_DEBUG(llvm::dbgs()
334 <<
"//===-------------------------------------------===//\n"
335 <<
"IR after bufferizing: " << nextOp->
getName() <<
"\n");
336 rewriter.setInsertionPoint(nextOp);
338 bufferizableOp.bufferize(rewriter,
options, bufferizationState))) {
339 LLVM_DEBUG(llvm::dbgs()
340 <<
"failed to bufferize\n"
341 <<
"//===-------------------------------------------===//\n");
342 return nextOp->
emitError(
"failed to bufferize op");
344 LLVM_DEBUG(llvm::dbgs()
346 <<
"\n//===-------------------------------------------===//\n");
350 if (erasedOps.contains(op))
355 rewriter.setInsertionPoint(op);
357 rewriter, cast<ToBufferOp>(op),
options);
362 if (toTensorOp->getUses().empty()) {
363 rewriter.eraseOp(toTensorOp);
376 if (erasedOps.contains(op))
389 if (isa<ToTensorOp, ToBufferOp>(op))
391 return op->
emitError(
"op was not bufferized");
409 auto tensorType = dyn_cast<TensorType>(bbArg.getType());
411 newTypes.push_back(bbArg.getType());
415 FailureOr<BaseMemRefType> memrefType =
417 if (failed(memrefType))
419 newTypes.push_back(*memrefType);
423 for (
auto [bbArg, type] : llvm::zip(block->
getArguments(), newTypes)) {
424 if (bbArg.getType() == type)
430 bbArgUses.push_back(&use);
432 Type tensorType = bbArg.getType();
438 if (!bbArgUses.empty()) {
439 Value toTensorOp = rewriter.
create<bufferization::ToTensorOp>(
440 bbArg.getLoc(), tensorType, bbArg);
442 use->set(toTensorOp);
448 auto branchOp = dyn_cast<BranchOpInterface>(op);
450 return op->
emitOpError(
"cannot bufferize ops with block references that "
451 "do not implement BranchOpInterface");
454 assert(it != op->
getSuccessors().end() &&
"could find successor");
455 int64_t successorIdx = std::distance(op->
getSuccessors().begin(), it);
459 for (
auto [operand, type] :
461 if (operand.getType() == type) {
463 newOperands.push_back(operand);
466 FailureOr<BaseMemRefType> operandBufferType =
468 if (failed(operandBufferType))
471 Value bufferizedOperand = rewriter.
create<bufferization::ToBufferOp>(
472 operand.getLoc(), *operandBufferType, operand);
475 if (type != *operandBufferType)
476 bufferizedOperand = rewriter.
create<memref::CastOp>(
477 operand.getLoc(), type, bufferizedOperand);
478 newOperands.push_back(bufferizedOperand);
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
user_range getUsers() const
Returns a range of all users.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
MLIRContext is the top-level object for a collection of MLIR operations.
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
SuccessorRange getSuccessors()
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
AnalysisState provides a variety of helper functions for dealing with tensor values.
BufferizationState provides information about the state of the IR during the bufferization process.
void allowOperation()
Allow the given ops.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
llvm::LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
The following effect indicates that the operation allocates from some resource.
Options for BufferizableOpInterface-based bufferization.
bool copyBeforeWrite
If set to true, the analysis is skipped.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption)
This function controls buffer types on function signatures.
bool allowUnknownOps
Specifies whether not bufferizable ops are allowed in the input.
unsigned int bufferAlignment
Buffer alignment for new memory allocations.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the ...
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
bool checkParallelRegions
UnknownTypeConverterFn unknownTypeConverterFn
Type converter from tensors to memrefs.
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not.
DefaultMemorySpaceFn defaultMemorySpaceFn
Bufferization statistics for debugging.
int64_t numTensorOutOfPlace
Options for analysis-enabled bufferization.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
@ BottomUpFromTerminators
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.