23#include "llvm/Support/DebugLog.h"
28#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
29#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
33#define DEBUG_TYPE "bufferize"
41parseHeuristicOption(
const std::string &s) {
46 if (s ==
"bottom-up-from-terminators")
47 return OneShotBufferizationOptions::AnalysisHeuristic::
48 BottomUpFromTerminators;
51 llvm_unreachable(
"invalid analysisheuristic option");
54struct OneShotBufferizePass
55 :
public bufferization::impl::OneShotBufferizePassBase<
56 OneShotBufferizePass> {
59 void runOnOperation()
override {
65 opt.allowUnknownOps = allowUnknownOps;
68 opt.copyBeforeWrite = copyBeforeWrite;
70 opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);
72 if (mustInferMemorySpace && useEncodingForMemorySpace) {
74 <<
"only one of 'must-infer-memory-space' and "
75 "'use-encoding-for-memory-space' are allowed in "
77 return signalPassFailure();
80 if (mustInferMemorySpace) {
81 opt.defaultMemorySpaceFn =
82 [](TensorLikeType t) -> std::optional<Attribute> {
87 if (useEncodingForMemorySpace) {
88 opt.defaultMemorySpaceFn =
89 [](TensorLikeType t) -> std::optional<Attribute> {
90 if (
auto rtt = dyn_cast<RankedTensorType>(t))
91 return rtt.getEncoding();
96 opt.printConflicts = printConflicts;
97 opt.bufferAlignment = bufferAlignment;
98 opt.testAnalysisOnly = testAnalysisOnly;
99 opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
100 opt.checkParallelRegions = checkParallelRegions;
104 LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
105 if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
107 "Invalid option: 'infer-layout-map' is not a valid value for "
108 "'unknown-type-conversion'");
109 return signalPassFailure();
111 opt.unknownTypeConverterFn = [=](TensorLikeType type,
114 const auto tensorType = cast<TensorType>(type);
115 if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
116 return cast<bufferization::BufferLikeType>(
117 bufferization::getMemRefTypeWithStaticIdentityLayout(
118 tensorType, memorySpace));
119 assert(unknownTypeConversionOption ==
120 LayoutMapOption::FullyDynamicLayoutMap &&
121 "invalid layout map option");
122 return cast<bufferization::BufferLikeType>(
123 bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
128 OpFilter::Entry::FilterFn filterFn = [&](
Operation *op) {
130 if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
131 return llvm::is_contained(this->dialectFilter,
132 op->getDialect()->getNamespace());
136 opt.opFilter.allowOperation(filterFn);
141 if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
147 "Invalid option: 'copy-before-write' cannot be used with "
148 "'test-analysis-only'");
149 return signalPassFailure();
152 if (opt.printConflicts && !opt.testAnalysisOnly) {
155 "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
156 return signalPassFailure();
162 "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
163 return signalPassFailure();
166 BufferizationState state;
168 ModuleOp moduleOp = getOperation();
169 if (opt.bufferizeFunctionBoundaries) {
178 "Invalid option: 'no-analysis-func-filter' requires "
179 "'bufferize-function-boundaries'");
180 return signalPassFailure();
195 std::optional<OneShotBufferizationOptions>
options;
209 SmallVector<Operation *> &worklist,
210 const BufferizationOptions &
options,
211 BufferizationStatistics *statistics)
212 : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
213 worklist(worklist), analysisState(
options), statistics(statistics) {
218 void notifyOperationErased(Operation *op)
override {
219 erasedOps.insert(op);
221 toBufferOps.erase(op);
224 void notifyOperationInserted(Operation *op, InsertPoint previous)
override {
226 if (previous.isSet())
233 if (
auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
234 statistics->numBufferAlloc +=
static_cast<int64_t
>(
235 sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
239 if (isa<ToBufferOp>(op)) {
240 toBufferOps.insert(op);
245 if (isa<ToTensorOp>(op))
249 if (!hasTensorSemantics(op))
253 auto const &
options = analysisState.getOptions();
258 worklist.push_back(op);
269 SmallVector<Operation *> &worklist;
273 const AnalysisState analysisState;
276 BufferizationStatistics *statistics;
282 BufferizationState &bufferizationState,
292 op->
walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });
303 if (
options.isOpAllowed(op) && hasTensorSemantics(op))
304 worklist.push_back(op);
311 BufferizationRewriter rewriter(op->
getContext(), erasedOps, toBufferOps,
312 worklist,
options, statistics);
313 for (
unsigned i = 0; i < worklist.size(); ++i) {
316 if (erasedOps.contains(nextOp))
319 auto bufferizableOp =
options.dynCastBufferizableOp(nextOp);
323 if (!hasTensorSemantics(nextOp))
326 if (!bufferizableOp.supportsUnstructuredControlFlow())
328 if (r.getBlocks().size() > 1)
330 "op or BufferizableOpInterface implementation does not support "
331 "unstructured control flow, but at least one region has multiple "
335 LDBG(3) <<
"//===-------------------------------------------===//\n"
336 <<
"IR after bufferizing: " << nextOp->
getName();
337 rewriter.setInsertionPoint(nextOp);
339 bufferizableOp.bufferize(rewriter,
options, bufferizationState))) {
340 LDBG(2) <<
"failed to bufferize\n"
341 <<
"//===-------------------------------------------===//";
342 return nextOp->
emitError(
"failed to bufferize op");
344 LDBG(3) << *op <<
"\n//===-------------------------------------------===//";
348 if (erasedOps.contains(op))
356 for (
Operation *op : toBufferOpsSnapshot) {
357 if (erasedOps.contains(op))
359 rewriter.setInsertionPoint(op);
361 rewriter, cast<ToBufferOp>(op),
options);
366 if (toTensorOp->getUses().empty()) {
367 rewriter.eraseOp(toTensorOp);
380 if (erasedOps.contains(op))
384 if (!hasTensorSemantics(op))
393 if (isa<ToTensorOp, ToBufferOp>(op))
395 return op->
emitError(
"op was not bufferized");
404 BufferizationState &state) {
413 auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
415 newTypes.push_back(bbArg.getType());
419 FailureOr<BufferLikeType> bufferType =
420 bufferization::getBufferType(bbArg,
options, state);
421 if (failed(bufferType))
423 newTypes.push_back(*bufferType);
427 for (
auto [bbArg, type] : llvm::zip(block->
getArguments(), newTypes)) {
428 if (bbArg.getType() == type)
434 bbArgUses.push_back(&use);
436 Type tensorType = bbArg.getType();
442 if (!bbArgUses.empty()) {
443 Value toTensorOp = bufferization::ToTensorOp::create(
444 rewriter, bbArg.getLoc(), tensorType, bbArg);
446 use->set(toTensorOp);
452 auto branchOp = dyn_cast<BranchOpInterface>(op);
454 return op->
emitOpError(
"cannot bufferize ops with block references that "
455 "do not implement BranchOpInterface");
458 assert(it != op->
getSuccessors().end() &&
"could find successor");
463 for (
auto [operand, type] :
465 if (operand.getType() == type) {
467 newOperands.push_back(operand);
470 FailureOr<BufferLikeType> operandBufferType =
471 bufferization::getBufferType(operand,
options, state);
472 if (failed(operandBufferType))
475 Value bufferizedOperand = bufferization::ToBufferOp::create(
476 rewriter, operand.getLoc(), *operandBufferType, operand);
479 if (type != *operandBufferType)
480 bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(),
481 type, bufferizedOperand);
482 newOperands.push_back(bufferizedOperand);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
user_range getUsers() const
Returns a range of all users.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
SuccessorRange getSuccessors()
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Bufferization statistics for debugging.
int64_t numTensorOutOfPlace
Options for analysis-enabled bufferization.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.