MLIR 22.0.0git
TensorCopyInsertion.cpp
Go to the documentation of this file.
1//===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16
17using namespace mlir;
18using namespace mlir::bufferization;
19
22 const BufferizationState &bufferizationState,
23 BufferizationStatistics *statistics) {
24 OneShotAnalysisState analysisState(op, options);
25 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
26 // analysis depending on whether function boundary bufferization is enabled or
27 // not.
28 if (options.bufferizeFunctionBoundaries) {
29 if (failed(analyzeModuleOp(op, analysisState, statistics)))
30 return failure();
31 } else {
32 if (failed(analyzeOp(op, analysisState, statistics)))
33 return failure();
34 }
35
36 if (options.testAnalysisOnly)
37 return success();
38
39 return insertTensorCopies(op, analysisState, bufferizationState);
40}
41
43 Operation *op, const AnalysisState &analysisState,
44 const BufferizationState &bufferizationState) {
45 IRRewriter rewriter(op->getContext());
46
47 // It may be more efficient to walk in pre-order here, but the current
48 // implementation visits regions of ops even if they are not allowed or
49 // bufferizable, and existing tests rely on this behavior.
50 // For now, only exclude nested operations if they are in a different symbol
51 // table scope.
52 WalkResult result = op->walk([&](Operation *nestedOp) {
53 if (op->hasTrait<OpTrait::SymbolTable>() &&
55 return WalkResult::skip();
56
57 auto bufferizableOp =
58 analysisState.getOptions().dynCastBufferizableOp(nestedOp);
59 if (!bufferizableOp)
60 return WalkResult::skip();
61
62 // Find inplacability conflicts and resolve them. (Typically with explicit
63 // tensor copies in the form of AllocTensorOps.)
64 rewriter.setInsertionPoint(nestedOp);
65 if (failed(bufferizableOp.resolveConflicts(rewriter, analysisState,
66 bufferizationState)))
67 return WalkResult::interrupt();
68
69 return WalkResult::advance();
70 });
71
72 return failure(result.wasInterrupted());
73}
return success()
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation is unregistered.
Definition Operation.h:749
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
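Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided. Nested elements are visited in post-order by default.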
Definition Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
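A utility result that is used to signal how to proceed with an ongoing walk: Interrupt stops the walk entirely; Advance continues it; Skip skips the not-yet-visited nested elements of the current operation, region or block and continues with the next one.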
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
State for analysis-enabled bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
Include the generated interface declarations.
Bufferization statistics for debugging.
Definition Bufferize.h:35
Options for analysis-enabled bufferization.