57 #include "llvm/ADT/DenseSet.h"
58 #include "llvm/ADT/SetVector.h"
64 #define DEBUG_TYPE "one-shot-analysis"
90 cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
94 if (isa<TensorType>(opOperand.
get().
getType()))
99 OpBuilder(op).getStrArrayAttr(inPlaceVector));
112 if (isa<TensorType>(v.getType()))
115 for (
Block &b : r.getBlocks())
116 for (
auto bbArg : b.getArguments())
117 if (isa<TensorType>(bbArg.getType()))
122 op->
walk([&](BufferizableOpInterface bufferizableOp) {
125 for (
OpOperand &opOperand : bufferizableOp->getOpOperands())
126 if (isa<TensorType>(opOperand.get().getType()))
127 if (bufferizableOp.mustBufferizeInPlace(opOperand, *
this))
135 auto leaderIt = equivalentInfo.findLeader(v);
136 for (
auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
144 auto leaderIt = aliasInfo.findLeader(v);
145 for (
auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
152 return equivalentInfo.isEquivalent(v1, v2);
157 return aliasInfo.isEquivalent(v1, v2);
161 if (inplaceBufferized.contains(&operand))
163 inplaceBufferized.insert(&operand);
165 aliasInfo.unionSets(alias.opResult, operand.
get());
166 ++statNumTensorInPlace;
170 assert(!inplaceBufferized.contains(&operand) &&
171 "OpOperand was already decided to bufferize inplace");
172 ++statNumTensorOutOfPlace;
177 equivalentInfo.insert(v);
190 Value returnVal = returnValOperand.get();
192 if (!isa<TensorType>(returnVal.getType()))
197 applyOnAliases(returnVal, [&](Value v) {
198 if (auto bbArg = dyn_cast<BlockArgument>(v)) {
199 if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
200 yieldedTensors.insert(bbArg);
203 Operation *definingOp = v.getDefiningOp();
204 if (definingOp->getParentOp() == returnOp->getParentOp())
205 yieldedTensors.insert(v);
213 void OneShotAnalysisState::gatherUndefinedTensorUses(
Operation *op) {
216 auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
222 if (!isa<TensorType>(opResult.getType()))
227 if (findDefinitionsCached(opResult).empty())
228 for (
OpOperand &use : opResult.getUses())
229 undefinedTensorUses.insert(&use);
236 bool OneShotAnalysisState::hasUndefinedContents(
OpOperand *opOperand)
const {
237 return undefinedTensorUses.contains(opOperand);
240 bool OneShotAnalysisState::isInPlace(
OpOperand &opOperand)
const {
241 return inplaceBufferized.contains(&opOperand);
244 bool OneShotAnalysisState::isTensorYielded(
Value tensor)
const {
245 return yieldedTensors.contains(tensor);
248 bool OneShotAnalysisState::isValueWritten(
Value value)
const {
249 bool isWritten =
false;
250 applyOnAliases(value, [&](
Value val) {
252 if (isInPlace(use) && bufferizesToMemoryWrite(use))
258 bool OneShotAnalysisState::isWritable(
Value value)
const {
260 if (
auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
261 return bufferizableOp.isWritable(value, *
this);
264 if (
auto bbArg = dyn_cast<BlockArgument>(value))
265 if (
auto bufferizableOp =
266 getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
267 return bufferizableOp.isWritable(bbArg, *
this);
273 void OneShotAnalysisState::unionAliasSets(
Value v1,
Value v2) {
274 aliasInfo.unionSets(v1, v2);
277 void OneShotAnalysisState::unionEquivalenceClasses(
Value v1,
Value v2) {
278 equivalentInfo.unionSets(v1, v2);
281 OneShotAnalysisState::Extension::~Extension() =
default;
390 for (
Value def : definitions) {
403 if (nextRegion == rDef)
405 assert(nextRegion &&
"expected to find another repetitive region");
419 static uint64_t counter = 0;
424 std::string
id =
"C_" + std::to_string(counter++);
426 std::string conflictingWriteAttr =
432 std::string readAttr =
436 if (
auto opResult = dyn_cast<OpResult>(definition)) {
437 std::string defAttr =
438 id +
"[DEF: result " + std::to_string(opResult.getResultNumber()) +
"]";
439 opResult.getDefiningOp()->setAttr(defAttr, b.
getUnitAttr());
441 auto bbArg = cast<BlockArgument>(definition);
442 std::string defAttr =
443 id +
"[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) +
"]";
444 bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.
getUnitAttr());
463 Operation *readingOp = uRead->getOwner();
464 LLVM_DEBUG(llvm::dbgs() <<
"\n- check conflict:\n");
465 LLVM_DEBUG(llvm::dbgs() <<
" uRead = operand " << uRead->getOperandNumber()
466 <<
" of " << *readingOp <<
"\n");
480 if (definitions.empty()) {
482 LLVM_DEBUG(llvm::dbgs()
483 <<
" no conflict: read value has no definitions\n");
489 for (
OpOperand *uConflictingWrite : usesWrite) {
490 LLVM_DEBUG(llvm::dbgs() <<
" unConflictingWrite = operand "
491 << uConflictingWrite->getOperandNumber() <<
" of "
492 << *uConflictingWrite->getOwner() <<
"\n");
498 LLVM_DEBUG(llvm::dbgs() <<
"\n- useDominance = " << useDominance <<
"\n");
502 Operation *conflictingWritingOp = uConflictingWrite->getOwner();
513 if (
happensBefore(readingOp, conflictingWritingOp, domInfo)) {
514 LLVM_DEBUG(llvm::dbgs()
515 <<
" no conflict: read happens before write\n");
526 if (uConflictingWrite == uRead) {
527 LLVM_DEBUG(llvm::dbgs()
528 <<
" no conflict: read and write are same use\n");
538 LLVM_DEBUG(llvm::dbgs() <<
" no conflict: read and write are in "
539 "mutually exclusive regions\n");
545 if (
auto bufferizableOp =
options.dynCastBufferizableOp(readingOp)) {
546 if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
547 LLVM_DEBUG(llvm::dbgs()
548 <<
" no conflict: op interace of reading op says 'no'\n");
553 if (conflictingWritingOp != readingOp) {
554 if (
auto bufferizableOp =
555 options.dynCastBufferizableOp(conflictingWritingOp)) {
556 if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
560 <<
" no conflict: op interace of writing op says 'no'\n");
567 for (
Value definition : definitions) {
568 LLVM_DEBUG(llvm::dbgs() <<
" * definition = " << definition <<
"\n");
571 if (
Operation *defOp = definition.getDefiningOp()) {
574 LLVM_DEBUG(llvm::dbgs()
575 <<
" no conflict: write happens before definition\n");
579 if (defOp->isProperAncestor(conflictingWritingOp)) {
582 <<
" no conflict: write is contained in definition\n");
586 auto bbArg = cast<BlockArgument>(definition);
587 Block *block = bbArg.getOwner();
589 LLVM_DEBUG(llvm::dbgs() <<
" no conflict: definition is bbArg "
590 "and write happens outside of block\n");
602 aliases.
getAliases()[0].opResult == definition) {
603 LLVM_DEBUG(llvm::dbgs()
604 <<
" no conflict: definition and write are same\n");
612 LLVM_DEBUG(llvm::dbgs() <<
" => RaW CONFLICT FOUND\n");
625 for (
auto &use : alias.
getUses())
636 for (
auto &use : alias.
getUses()) {
638 if (state.bufferizesToMemoryRead(use)) {
659 AliasingOpResultList aliases = state.getAliasingOpResults(use);
660 if (llvm::any_of(aliases, [&](AliasingOpResult a) {
661 return state.isValueRead(a.opResult);
710 usesWrite.insert(&operand);
717 static int64_t counter = 0;
719 std::string
id =
"W_" + std::to_string(counter++);
720 if (
auto opResult = dyn_cast<OpResult>(value)) {
721 std::string attr =
id +
"[NOT-WRITABLE: result " +
722 std::to_string(opResult.getResultNumber()) +
"]";
723 opResult.getDefiningOp()->setAttr(attr, b.
getUnitAttr());
725 auto bbArg = cast<BlockArgument>(value);
726 std::string attr =
id +
"[NOT-WRITABLE: bbArg " +
727 std::to_string(bbArg.getArgNumber()) +
"]";
728 bbArg.getOwner()->getParentOp()->setAttr(attr, b.
getUnitAttr());
737 bool checkConsistencyOnly =
false) {
747 foundWrite = !usesWrite.empty();
754 bool foundReadOnly =
false;
755 auto checkReadOnly = [&](
Value v) {
757 foundReadOnly =
true;
766 LLVM_DEBUG(llvm::dbgs() <<
"=> NOT WRITABLE\n");
779 OneShotAnalysisState::findDefinitionsCached(
Value value) {
780 if (!cachedDefinitions.count(value))
782 return cachedDefinitions[value];
792 llvm::dbgs() <<
"//===-------------------------------------------===//\n"
794 <<
" of " << *operand.
getOwner() <<
"\n");
796 bool foundInterference =
800 if (foundInterference)
805 LLVM_DEBUG(llvm::dbgs()
806 <<
"//===-------------------------------------------===//\n");
814 if (isa<TensorType>(opOperand.get().getType()))
824 return hasTensorResult || hasTensorOperand;
833 if (!isa<TensorType>(opResult.getType()))
840 Value firstOperand = aliases.
begin()->opOperand->get();
841 bool allEquivalent =
true;
844 bool isInPlace = state.
isInPlace(*alias.opOperand);
845 Value operand = alias.opOperand->get();
846 if (isEquiv && isInPlace && alias.isDefinite) {
850 allEquivalent =
false;
853 if (!isEquiv || !isInPlace)
854 allEquivalent =
false;
856 allEquivalent =
false;
869 if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
906 std::mt19937 g(
getOptions().analysisFuzzerSeed);
907 llvm::shuffle(ops.begin(), ops.end(), g);
917 }
else if (heuristic ==
923 llvm_unreachable(
"unsupported heuristic");
936 WalkResult walkResult = op->
walk([&](BufferizableOpInterface op) {
938 if (!
options.isOpAllowed(op.getOperation()))
943 if (isa<ToMemrefOp>(op.getOperation())) {
944 op->emitError(
"to_memref ops are not supported by One-Shot Analysis");
945 return WalkResult::interrupt();
951 if (
auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
952 if (!toTensorOp.getRestrict()) {
953 op->emitError(
"to_tensor ops without `restrict` are not supported by "
954 "One-Shot Analysis");
955 return WalkResult::interrupt();
960 if (isa<TensorType>(opOperand.get().getType())) {
961 if (wouldCreateReadAfterWriteInterference(
962 opOperand, domInfo, state,
967 op->emitError(
"input IR has RaW conflict");
968 return WalkResult::interrupt();
976 return success(!walkResult.wasInterrupted());
986 if (isa<TensorType>(opOperand.get().getType()))
998 if (llvm::isa<TensorType>(opResult.getType())) {
999 SmallVector<Attribute> aliases;
1000 state.applyOnAliases(opResult, [&](Value alias) {
1002 llvm::raw_string_ostream stream(buffer);
1003 alias.printAsOperand(stream, asmState);
1004 aliases.push_back(b.getStringAttr(stream.str()));
1006 aliasSets.push_back(b.getArrayAttr(aliases));
1009 if (!aliasSets.empty())
1054 Value returnVal = returnValOperand.get();
1056 if (!isa<TensorType>(returnVal.getType()))
1059 bool foundEquivValue = false;
1060 state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
1061 if (auto bbArg = dyn_cast<BlockArgument>(equivVal)) {
1062 Operation *definingOp = bbArg.getOwner()->getParentOp();
1063 if (definingOp->isProperAncestor(returnOp))
1064 foundEquivValue = true;
1068 Operation *definingOp = equivVal.getDefiningOp();
1069 if (definingOp->getBlock()->findAncestorOpInBlock(
1070 *returnOp->getParentOp()))
1072 if (happensBefore(definingOp, returnOp, domInfo))
1073 foundEquivValue = true;
1078 if (!foundEquivValue)
1080 <<
"operand #" << returnValOperand.getOperandNumber()
1081 <<
" may return/yield a new buffer allocation";
1108 bool failedAnalysis =
false;
1109 if (!
options.allowReturnAllocs)
1120 if (BufferizableOpInterface bufferizableOp =
1121 options.dynCastBufferizableOp(op))
1122 failedAnalysis |=
failed(bufferizableOp.verifyAnalysis(state));
1131 return success(!failedAnalysis);
1139 "invalid combination of bufferization flags");
1140 if (!
options.copyBeforeWrite) {
1148 nullptr, statistics);
static bool hasReadAfterWriteInterference(const DenseSet< OpOperand * > &usesRead, const DenseSet< OpOperand * > &usesWrite, const DominanceInfo &domInfo, OneShotAnalysisState &state)
Given sets of uses and writes, return true if there is a RaW conflict under the assumption that all g...
static void getAliasingReads(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)
static void equivalenceAnalysis(SmallVector< Operation * > &ops, OneShotAnalysisState &state)
Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace)
Mark whether OpOperand will be bufferized inplace.
constexpr StringLiteral kInPlaceOperandsAttrName
Attribute marker to specify op operands that bufferize in-place.
static bool isaTensor(Type t)
static bool hasTensorSemantics(Operation *op)
Return true if the given op has a tensor result or a tensor operand.
static void annotateNonWritableTensor(Value value)
Annotate IR with details about the detected non-writability conflict.
constexpr StringLiteral kAliasSetAttrName
static LogicalResult checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, OneShotAnalysisState &state)
Assert that the current bufferization decisions are consistent.
static LogicalResult bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, const DominanceInfo &domInfo)
Determine if operand can be bufferized in-place.
static bool happensBefore(Operation *a, Operation *b, const DominanceInfo &domInfo)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
static bool wouldCreateWriteToNonWritableBuffer(OpOperand &operand, OneShotAnalysisState &state, bool checkConsistencyOnly=false)
Return true if bufferizing operand inplace would create a write to a non-writable buffer.
static void annotateOpsWithAliasSets(Operation *op, const OneShotAnalysisState &state)
static void annotateOpsWithBufferizationMarkers(Operation *op, const OneShotAnalysisState &state)
Annotate the IR with the result of the analysis. For testing/debugging only.
static bool wouldCreateReadAfterWriteInterference(OpOperand &operand, const DominanceInfo &domInfo, OneShotAnalysisState &state, bool checkConsistencyOnly=false)
Return true if bufferizing operand inplace would create a conflict.
bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, const AnalysisState &state)
Return true if op dominance can be used to rule out a read-after-write conflicts based on the orderin...
static void getAliasingInplaceWrites(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)
static LogicalResult assertNoAllocsReturned(Operation *op, const OneShotAnalysisState &state)
Assert that every allocation can be deallocated in the same block.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, Value definition)
Annotate IR with details about the detected RaW conflict.
static bool isInplaceMemoryWrite(OpOperand &opOperand, const OneShotAnalysisState &state)
Return true if opOperand has been decided to bufferize in-place.
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Base class for generic analysis states.
This class provides management for the lifetime of the state used when printing the IR.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
This class is a general helper class for creating context-global objects like types,...
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
size_t getNumAliases() const
ArrayRef< T > getAliases() const
AnalysisState provides a variety of helper functions for dealing with tensor values.
AliasingOpOperandList getAliasingOpOperands(OpResult result) const
Determine which OpOperand* will alias with result if the op is bufferized in place.
bool bufferizesToMemoryWrite(OpOperand &opOperand) const
Return true if opOperand bufferizes to a memory write.
AliasingOpResultList getAliasingOpResults(OpOperand &opOperand) const
Determine which OpResult will alias with opOperand if the op is bufferized in place.
SetVector< Value > findDefinitions(Value value) const
Find the values that may define the contents of the given value at runtime.
State for analysis-enabled bufferization.
void bufferizeOutOfPlace(OpOperand &operand)
Mark the given OpOperand as out-of-place.
bool isWritable(Value value) const
Return true if the buffer of the given tensor value is writable.
const SetVector< Value > & findDefinitionsCached(Value value)
Find the definitions of the given tensor value or retrieve them from the cache.
bool isInPlace(OpOperand &opOperand) const override
Return true if the given OpResult has been decided to bufferize inplace.
LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo)
Analyze the given op and its nested ops.
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
void unionEquivalenceClasses(Value v1, Value v2)
Union the equivalence classes of v1 and v2.
void gatherUndefinedTensorUses(Operation *op)
Find all tensor values in the given operation that have undefined contents and store them in undefine...
LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo)
Analyze a single op (without nested ops).
void gatherYieldedTensors(Operation *op)
Find all tensors that are yielded/returned from a block and store them in yieldedTensors.
void applyOnEquivalenceClass(Value v, function_ref< void(Value)> fun) const
Apply fun to all the members of the equivalence class of v.
int64_t getStatNumTensorOutOfPlace() const
void resetCache()
Reset cached data structures.
void bufferizeInPlace(OpOperand &operand)
Mark the given OpOperand as in-place and merge the results' and operand's aliasing sets.
void applyOnAliases(Value v, function_ref< void(Value)> fun) const
Apply fun to all aliases of v.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
OneShotAnalysisState(Operation *op, const OneShotBufferizationOptions &options)
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
void createAliasInfoEntry(Value v)
Add a new entry for v in the aliasInfo and equivalentInfo.
int64_t getStatNumTensorInPlace() const
Operation * getOwner() const
Return the owner of this operand.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, bool copyBeforeWrite=true, const OpFilter *opFilter=nullptr, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
Region * getNextEnclosingRepetitiveRegion(Region *region, const BufferizationOptions &options)
Assuming that the given region is repetitive, find the next enclosing repetitive region.
This header declares functions that assit transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool isRegionReturnLike(Operation *operation)
Returns true if the given operation is either annotated with the ReturnLike trait or implements the R...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
A maybe aliasing OpOperand.
A maybe aliasing OpResult.
Options for BufferizableOpInterface-based bufferization.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const
Try to cast the given op to BufferizableOpInterface if the op is allow listed.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
bool isOpAllowed(Operation *op) const
Return true if the given op should be bufferized.
Bufferization statistics for debugging.
int64_t numTensorOutOfPlace
Options for analysis-enabled bufferization.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.