28 namespace bufferization {
29 #define GEN_PASS_DEF_FINALIZINGBUFFERIZE
30 #define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
31 #define GEN_PASS_DEF_ONESHOTBUFFERIZE
32 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
36 #define DEBUG_TYPE "bufferize"
47 assert(inputs.size() == 1);
48 assert(isa<BaseMemRefType>(inputs[0].getType()));
49 return builder.
create<bufferization::ToTensorOp>(loc, type, inputs[0]);
68 assert(inputs.size() == 1 &&
"expected exactly one input");
70 if (
auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
72 assert(inputType != type &&
"expected different types");
74 auto rankedDestType = dyn_cast<MemRefType>(type);
84 if (isa<TensorType>(inputs[0].getType())) {
86 return builder.
create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
89 llvm_unreachable(
"only tensor/memref input types supported");
95 target.
addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
101 class BufferizeToTensorOp
106 matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
108 rewriter.
replaceOp(op, adaptor.getMemref());
117 class BufferizeToMemrefOp
122 matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
124 rewriter.
replaceOp(op, adaptor.getTensor());
132 patterns.
add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
137 struct FinalizingBufferizePass
138 :
public bufferization::impl::FinalizingBufferizeBase<
139 FinalizingBufferizePass> {
140 using FinalizingBufferizeBase<
141 FinalizingBufferizePass>::FinalizingBufferizeBase;
143 void runOnOperation()
override {
144 auto func = getOperation();
161 target.markUnknownOpDynamicallyLegal(
169 static LayoutMapOption parseLayoutMapOption(
const std::string &s) {
170 if (s ==
"fully-dynamic-layout-map")
171 return LayoutMapOption::FullyDynamicLayoutMap;
172 if (s ==
"identity-layout-map")
173 return LayoutMapOption::IdentityLayoutMap;
174 if (s ==
"infer-layout-map")
175 return LayoutMapOption::InferLayoutMap;
176 llvm_unreachable(
"invalid layout map option");
180 parseHeuristicOption(
const std::string &s) {
181 if (s ==
"bottom-up")
185 llvm_unreachable(
"invalid analysisheuristic option");
188 struct OneShotBufferizePass
189 :
public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
190 OneShotBufferizePass() =
default;
197 .
insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
200 void runOnOperation()
override {
212 parseLayoutMapOption(functionBoundaryTypeConversion));
213 if (mustInferMemorySpace)
221 LayoutMapOption unknownTypeConversionOption =
222 parseLayoutMapOption(unknownTypeConversion);
223 if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
225 "Invalid option: 'infer-layout-map' is not a valid value for "
226 "'unknown-type-conversion'");
227 return signalPassFailure();
231 auto tensorType = cast<TensorType>(value.
getType());
232 if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
234 tensorType, memorySpace);
235 assert(unknownTypeConversionOption ==
236 LayoutMapOption::FullyDynamicLayoutMap &&
237 "invalid layout map option");
245 if (this->dialectFilter.hasValue())
246 return llvm::is_contained(this->dialectFilter,
262 "Invalid option: 'copy-before-write' cannot be used with "
263 "'test-analysis-only'");
264 return signalPassFailure();
270 "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
271 return signalPassFailure();
277 "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
278 return signalPassFailure();
282 ModuleOp moduleOp = getOperation();
291 "Invalid option: 'no-analysis-func-filter' requires "
292 "'bufferize-function-boundaries'");
293 return signalPassFailure();
308 std::optional<OneShotBufferizationOptions>
options;
313 struct BufferizationBufferizePass
314 :
public bufferization::impl::BufferizationBufferizeBase<
315 BufferizationBufferizePass> {
316 void runOnOperation()
override {
318 options.opFilter.allowDialect<BufferizationDialect>();
326 .
insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
332 return std::make_unique<BufferizationBufferizePass>();
336 return std::make_unique<OneShotBufferizePass>();
341 return std::make_unique<OneShotBufferizePass>(
options);
344 std::unique_ptr<OperationPass<func::FuncOp>>
346 return std::make_unique<FinalizingBufferizePass>();
358 return any_of(r.getBlocks(), [](Block &b) {
359 return any_of(b.getArguments(), [](BlockArgument bbArg) {
360 return isaTensor(bbArg.getType());
364 if (hasTensorBlockArgument)
367 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
368 bool hasTensorArg = any_of(funcOp.getArgumentTypes(),
isaTensor);
369 bool hasTensorResult = any_of(funcOp.getResultTypes(),
isaTensor);
370 return hasTensorArg || hasTensorResult;
375 return hasTensorResult || hasTensorOperand;
387 :
IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
388 worklist(worklist), analysisState(
options), statistics(statistics) {
393 void notifyOperationRemoved(
Operation *op)
override {
394 erasedOps.insert(op);
396 toMemrefOps.erase(op);
399 void notifyOperationInserted(
Operation *op)
override {
404 if (
auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
410 if (isa<ToMemrefOp>(op)) {
411 toMemrefOps.insert(op);
416 if (isa<ToTensorOp>(op))
424 auto const &
options = analysisState.getOptions();
429 worklist.push_back(op);
462 op->
walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
474 worklist.push_back(op);
481 BufferizationRewriter rewriter(op->
getContext(), erasedOps, toMemrefOps,
482 worklist,
options, statistics);
483 for (
unsigned i = 0; i < worklist.size(); ++i) {
486 if (erasedOps.contains(nextOp))
489 auto bufferizableOp =
options.dynCastBufferizableOp(nextOp);
492 if (!
options.isOpAllowed(nextOp))
498 if (!bufferizableOp.supportsUnstructuredControlFlow())
500 if (r.getBlocks().size() > 1)
502 "op or BufferizableOpInterface implementation does not support "
503 "unstructured control flow, but at least one region has multiple "
507 LLVM_DEBUG(llvm::dbgs()
508 <<
"//===-------------------------------------------===//\n"
509 <<
"IR after bufferizing: " << nextOp->
getName() <<
"\n");
510 rewriter.setInsertionPoint(nextOp);
512 LLVM_DEBUG(llvm::dbgs()
513 <<
"failed to bufferize\n"
514 <<
"//===-------------------------------------------===//\n");
515 return nextOp->
emitError(
"failed to bufferize op");
517 LLVM_DEBUG(llvm::dbgs()
519 <<
"\n//===-------------------------------------------===//\n");
524 rewriter.setInsertionPoint(op);
526 cast<ToMemrefOp>(op));
531 if (toTensorOp->getUses().empty()) {
532 rewriter.eraseOp(toTensorOp);
545 if (erasedOps.contains(op))
558 if (isa<ToTensorOp, ToMemrefOp>(op))
560 return op->
emitError(
"op was not bufferized");
577 auto tensorType = dyn_cast<TensorType>(bbArg.getType());
579 newTypes.push_back(bbArg.getType());
587 newTypes.push_back(*memrefType);
591 for (
auto [bbArg, type] : llvm::zip(block->
getArguments(), newTypes)) {
592 if (bbArg.getType() == type)
598 bbArgUses.push_back(&use);
605 if (!bbArgUses.empty()) {
607 rewriter.
create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
609 use->set(toTensorOp);
615 auto branchOp = dyn_cast<BranchOpInterface>(op);
617 return op->
emitOpError(
"cannot bufferize ops with block references that "
618 "do not implement BranchOpInterface");
621 assert(it != op->
getSuccessors().end() &&
"could find successor");
622 int64_t successorIdx = std::distance(op->
getSuccessors().begin(), it);
626 for (
auto [operand, type] :
628 if (operand.getType() == type) {
630 newOperands.push_back(operand);
635 if (
failed(operandBufferType))
638 Value bufferizedOperand = rewriter.
create<bufferization::ToMemrefOp>(
639 operand.getLoc(), *operandBufferType, operand);
642 if (type != *operandBufferType)
643 bufferizedOperand = rewriter.
create<memref::CastOp>(
644 operand.getLoc(), type, bufferizedOperand);
645 newOperands.push_back(bufferizedOperand);
655 options.allowUnknownOps =
true;
656 options.copyBeforeWrite =
true;
657 options.enforceAliasingInvariants =
false;
661 cast<TensorType>(value.
getType()), memorySpace);
663 options.opFilter.allowDialect<BufferizationDialect>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
StringRef getNamespace() const
This class provides support for representing a failure result, or a valid value of type T.
user_range getUsers() const
Returns a range of all users.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
SuccessorRange getSuccessors()
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<Value>(OpBuilder &, T, ValueRange, Location), where T is any subclass of Type.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illegal source type.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal, or source, type to a legal type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static WalkResult advance()
AnalysisState provides a variety of helper functions for dealing with tensor values.
A helper type converter class that automatically populates the relevant materializations and type conversions for bufferization.
BufferizeTypeConverter()
Registers conversions into BufferizeTypeConverter.
void allowOperation()
Allow the given ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
void populateEliminateBufferizeMaterializationsPatterns(BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns)
Populate patterns to eliminate bufferize materializations.
LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
void populateBufferizeMaterializationLegality(ConversionTarget &target)
Marks ops used by bufferization for type conversion materializations as "legal" in the given ConversionTarget.
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref)
Try to fold to_memref(to_tensor(x)).
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
std::unique_ptr< OperationPass< func::FuncOp > > createFinalizingBufferizePass()
Creates a pass that finalizes a partial bufferization by removing remaining bufferization.to_tensor and bufferization.to_memref operations.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor).
std::unique_ptr< Pass > createOneShotBufferizePass()
Create a pass that bufferizes all ops that implement BufferizableOpInterface with One-Shot Bufferize.
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
BufferizationOptions getPartialBufferizationOptions()
Return BufferizationOptions such that the bufferizeOp behaves like the old (deprecated) partial, dialect conversion-based bufferization passes.
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
std::unique_ptr< Pass > createBufferizationBufferizePass()
Create a pass that bufferizes ops from the bufferization dialect.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
The following effect indicates that the operation allocates from some resource.
Options for BufferizableOpInterface-based bufferization.
bool copyBeforeWrite
If set to true, the analysis is skipped.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption)
This function controls buffer types on function signatures.
bool allowUnknownOps
Specifies whether not bufferizable ops are allowed in the input.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the analysis) and post analysis steps.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
std::optional< Attribute > defaultMemorySpace
The default memory space that should be used when it cannot be inferred from the context.
UnknownTypeConverterFn unknownTypeConverterFn
Type converter from tensors to memrefs.
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not.
Bufferization statistics for debugging.
int64_t numTensorOutOfPlace
Options for analysis-enabled bufferization.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.