27 namespace bufferization {
28 #define GEN_PASS_DEF_FINALIZINGBUFFERIZE
29 #define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
30 #define GEN_PASS_DEF_ONESHOTBUFFERIZE
31 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
35 #define DEBUG_TYPE "bufferize"
46 assert(inputs.size() == 1);
47 assert(inputs[0].getType().isa<BaseMemRefType>());
48 return builder.
create<bufferization::ToTensorOp>(loc, type, inputs[0]);
57 return MemRefType::get(type.getShape(), type.getElementType());
61 return UnrankedMemRefType::get(type.getElementType(), 0);
67 assert(inputs.size() == 1 &&
"expected exactly one input");
69 if (
auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
71 assert(inputType != type &&
"expected different types");
73 auto rankedDestType = type.
dyn_cast<MemRefType>();
83 if (inputs[0].getType().isa<TensorType>()) {
85 return builder.
create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
88 llvm_unreachable(
"only tensor/memref input types supported");
94 target.
addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
100 class BufferizeToTensorOp
105 matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
107 rewriter.
replaceOp(op, adaptor.getMemref());
116 class BufferizeToMemrefOp
121 matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
123 rewriter.
replaceOp(op, adaptor.getTensor());
131 patterns.
add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
136 struct FinalizingBufferizePass
137 :
public bufferization::impl::FinalizingBufferizeBase<
138 FinalizingBufferizePass> {
139 using FinalizingBufferizeBase<
140 FinalizingBufferizePass>::FinalizingBufferizeBase;
142 void runOnOperation()
override {
143 auto func = getOperation();
144 auto *context = &getContext();
160 target.markUnknownOpDynamicallyLegal(
168 static LayoutMapOption parseLayoutMapOption(
const std::string &s) {
169 if (s ==
"fully-dynamic-layout-map")
170 return LayoutMapOption::FullyDynamicLayoutMap;
171 if (s ==
"identity-layout-map")
172 return LayoutMapOption::IdentityLayoutMap;
173 if (s ==
"infer-layout-map")
174 return LayoutMapOption::InferLayoutMap;
175 llvm_unreachable(
"invalid layout map option");
179 parseHeuristicOption(
const std::string &s) {
180 if (s ==
"bottom-up")
184 llvm_unreachable(
"invalid analysisheuristic option");
187 struct OneShotBufferizePass
188 :
public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
189 OneShotBufferizePass() =
default;
196 .
insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
200 void runOnOperation()
override {
212 parseLayoutMapOption(functionBoundaryTypeConversion);
213 if (mustInferMemorySpace)
221 LayoutMapOption unknownTypeConversionOption =
222 parseLayoutMapOption(unknownTypeConversion);
226 if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
228 tensorType, memorySpace);
229 assert(unknownTypeConversionOption ==
230 LayoutMapOption::FullyDynamicLayoutMap &&
231 "invalid layout map option");
239 if (this->dialectFilter.hasValue())
240 return llvm::is_contained(this->dialectFilter,
251 ModuleOp moduleOp = getOperation();
259 "invalid combination of bufferization flags");
279 (void)runPipeline(cleanupPipeline, moduleOp);
283 std::optional<OneShotBufferizationOptions>
options;
288 struct BufferizationBufferizePass
289 :
public bufferization::impl::BufferizationBufferizeBase<
290 BufferizationBufferizePass> {
291 void runOnOperation()
override {
293 options.opFilter.allowDialect<BufferizationDialect>();
301 .
insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
307 return std::make_unique<BufferizationBufferizePass>();
311 return std::make_unique<OneShotBufferizePass>();
316 return std::make_unique<OneShotBufferizePass>(
options);
319 std::unique_ptr<OperationPass<func::FuncOp>>
321 return std::make_unique<FinalizingBufferizePass>();
332 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
333 bool hasTensorArg = any_of(funcOp.getArgumentTypes(),
isaTensor);
334 bool hasTensorResult = any_of(funcOp.getResultTypes(),
isaTensor);
335 return hasTensorArg || hasTensorResult;
340 return hasTensorResult || hasTensorOperand;
345 class BufferizationRewriter :
public IRRewriter {
353 :
IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
354 worklist(worklist), analysisState(
options), opFilter(opFilter),
355 statistics(statistics) {}
358 void notifyOperationRemoved(
Operation *op)
override {
360 erasedOps.insert(op);
362 toMemrefOps.erase(op);
365 void notifyOperationInserted(
Operation *op)
override {
371 if (
auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op)) {
380 if (isa<ToMemrefOp>(op)) {
381 toMemrefOps.insert(op);
386 if (isa<ToTensorOp>(op))
394 auto const &
options = analysisState.getOptions();
395 if (!
options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
399 worklist.push_back(op);
426 bool copyBeforeWrite,
429 if (copyBeforeWrite) {
437 op->
walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
450 op->
walk([&](func::FuncOp funcOp) {
452 worklist.push_back(funcOp);
456 worklist.push_back(op);
463 BufferizationRewriter rewriter(op->
getContext(), erasedOps, toMemrefOps,
464 worklist,
options, opFilter, statistics);
465 for (
unsigned i = 0; i < worklist.size(); ++i) {
468 if (erasedOps.contains(nextOp))
471 auto bufferizableOp =
options.dynCastBufferizableOp(nextOp);
480 LLVM_DEBUG(llvm::dbgs()
481 <<
"//===-------------------------------------------===//\n"
482 <<
"IR after bufferizing: " << nextOp->
getName() <<
"\n");
483 rewriter.setInsertionPoint(nextOp);
485 LLVM_DEBUG(llvm::dbgs()
486 <<
"failed to bufferize\n"
487 <<
"//===-------------------------------------------===//\n");
488 return nextOp->
emitError(
"failed to bufferize op");
490 LLVM_DEBUG(llvm::dbgs()
492 <<
"\n//===-------------------------------------------===//\n");
497 rewriter.setInsertionPoint(op);
499 cast<ToMemrefOp>(op));
504 if (toTensorOp->getUses().empty()) {
505 rewriter.eraseOp(toTensorOp);
518 if (erasedOps.contains(op))
533 if (isa<ToTensorOp, ToMemrefOp>(op))
535 return op->
emitError(
"op was not bufferized");
543 options.allowUnknownOps =
true;
544 options.createDeallocs =
false;
545 options.enforceAliasingInvariants =
false;
551 options.opFilter.allowDialect<BufferizationDialect>();
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
StringRef getNamespace() const
This class provides support for representing a failure result, or a valid value of type T.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a pass manager that runs passes on either a specific operation type,...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void notifyOperationRemoved(Operation *op)
This is called on an operation that a rewrite is removing, right before the operation is deleted.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
void addConversion(FnT &&callback)
Register a conversion function.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
AnalysisState provides a variety of helper functions for dealing with tensor values.
A helper type converter class that automatically populates the relevant materializations and type con...
BufferizeTypeConverter()
Registers conversions into BufferizeTypeConverter.
bool isOpAllowed(Operation *op) const
Return whether the op is allowed or not.
void allowOperation()
Allow the given ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
void populateEliminateBufferizeMaterializationsPatterns(BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns)
Populate patterns to eliminate bufferize materializations.
LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
void populateBufferizeMaterializationLegality(ConversionTarget &target)
Marks ops used by bufferization for type conversion materializations as "legal" in the given Conversi...
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref)
Try to fold to_memref(to_tensor(x)).
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, bool copyBeforeWrite=true, const OpFilter *opFilter=nullptr, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry)
Register external models for AllocationOpInterface.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
std::unique_ptr< OperationPass< func::FuncOp > > createFinalizingBufferizePass()
Creates a pass that finalizes a partial bufferization by removing remaining bufferization....
std::unique_ptr< Pass > createOneShotBufferizePass()
Create a pass that bufferizes all ops that implement BufferizableOpInterface with One-Shot Bufferize.
BufferizationOptions getPartialBufferizationOptions()
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
std::unique_ptr< Pass > createBufferizationBufferizePass()
Create a pass that bufferizes ops from the bufferization dialect.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createCSEPass()
Creates a pass to perform common sub expression elimination.
std::unique_ptr< Pass > createLoopInvariantCodeMotionPass()
Creates a loop invariant code motion pass that hoists loop invariant instructions out of the loop.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overrid...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
The following effect indicates that the operation allocates from some resource.
The following effect indicates that the operation frees some resource that has been allocated.
virtual void notifyOperationInserted(Operation *op)
Notification handler for when an operation is inserted into the builder.
Options for BufferizableOpInterface-based bufferization.
bool createDeallocs
Specifies whether dealloc ops should be generated along with alloc ops.
bool copyBeforeWrite
If set to true, the analysis is skipped.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
bool allowUnknownOps
Specifies whether not bufferizable ops are allowed in the input.
LayoutMapOption functionBoundaryTypeConversion
This flag controls buffer types on function signatures.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the ...
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
std::optional< Attribute > defaultMemorySpace
The default memory space that should be used when it cannot be inferred from the context.
UnknownTypeConverterFn unknownTypeConverterFn
Type converter from tensors to memrefs.
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not.
Bufferization statistics for debugging.
int64_t numTensorOutOfPlace
Options for analysis-enabled bufferization.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
bool allowReturnAllocs
Specifies whether returning newly allocated memrefs should be allowed.
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.