33 assert(isa<BaseMemRefType>(type) &&
"expected BaseMemRefType");
34 assert(isa<BaseMemRefType>(buffer.
getType()) &&
"expected BaseMemRefType");
41 assert(memref::CastOp::areCastCompatible(buffer.
getType(), type) &&
42 "scf.while op bufferization: cast incompatible");
43 return b.
create<memref::CastOp>(buffer.
getLoc(), type, buffer).getResult();
47 struct ConditionOpInterface
48 :
public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
75 auto conditionOp = cast<scf::ConditionOp>(op);
76 auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
80 Value value = it.value();
81 if (isa<TensorType>(value.
getType())) {
86 whileOp.getAfterArguments()[it.index()],
options);
89 Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
90 newArgs.push_back(buffer);
92 newArgs.push_back(value);
96 replaceOpWithNewBufferizedOp<scf::ConditionOp>(
97 rewriter, op, conditionOp.getCondition(), newArgs);
104 struct ExecuteRegionOpInterface
105 :
public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
106 scf::ExecuteRegionOp> {
114 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
115 size_t resultNum = std::distance(op->
getOpResults().begin(),
118 assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
119 "expected exactly 1 block");
120 auto yieldOp = dyn_cast<scf::YieldOp>(
121 executeRegionOp.getRegion().front().getTerminator());
122 assert(yieldOp &&
"expected scf.yield terminator in scf.execute_region");
128 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
129 assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
130 "only 1 block supported");
132 cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
133 TypeRange newResultTypes(yieldOp.getResults());
137 rewriter.
create<scf::ExecuteRegionOp>(op->
getLoc(), newResultTypes);
143 for (
const auto &it :
llvm::enumerate(executeRegionOp->getResultTypes())) {
144 if (isa<TensorType>(it.value())) {
145 newResults.push_back(rewriter.
create<bufferization::ToTensorOp>(
146 executeRegionOp.getLoc(), newOp->getResult(it.index())));
148 newResults.push_back(newOp->getResult(it.index()));
153 rewriter.
replaceOp(executeRegionOp, newResults);
161 :
public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
169 auto ifOp = cast<scf::IfOp>(op);
170 size_t resultNum = std::distance(op->
getOpResults().begin(),
172 OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
173 OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
181 auto ifOp = cast<scf::IfOp>(op);
185 for (
Value result : ifOp.getResults()) {
186 if (!isa<TensorType>(result.getType())) {
187 newTypes.push_back(result.getType());
193 newTypes.push_back(*bufferType);
199 rewriter.
create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
203 rewriter.
mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
204 rewriter.
mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
215 auto ifOp = cast<scf::IfOp>(op);
216 auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
217 auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
221 auto opResult = cast<OpResult>(value);
225 if (isa<BaseMemRefType>(thenValue.getType())) {
227 thenBufferType = cast<BaseMemRefType>(thenValue.getType());
229 auto maybeBufferType =
231 if (
failed(maybeBufferType))
233 thenBufferType = *maybeBufferType;
235 if (isa<BaseMemRefType>(elseValue.getType())) {
237 elseBufferType = cast<BaseMemRefType>(elseValue.getType());
239 auto maybeBufferType =
241 if (
failed(maybeBufferType))
243 elseBufferType = *maybeBufferType;
247 if (thenBufferType == elseBufferType)
248 return thenBufferType;
252 return op->
emitError(
"inconsistent memory space on then/else branches");
265 if (isa<TensorType>(it.value().getType()))
266 result.insert(it.index());
275 unsigned int minSize =
std::min(bbArgs.size(), yieldedValues.size());
277 for (
unsigned int i = 0; i < minSize; ++i) {
278 if (!isa<TensorType>(bbArgs[i].getType()) ||
279 !isa<TensorType>(yieldedValues[i].getType()))
281 if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
294 if (isa<TensorType>(opOperand.get().getType())) {
299 result.push_back(*resultBuffer);
301 result.push_back(opOperand.get());
315 size_t idx = it.index();
316 Value val = it.value();
317 if (tensorIndices.contains(idx)) {
319 rewriter.
create<bufferization::ToTensorOp>(val.
getLoc(), val)
322 result.push_back(val);
345 auto initArgBufferType =
347 if (
failed(initArgBufferType))
360 newFixedTypes[iterArg] = *initArgBufferType;
364 if (isa<BaseMemRefType>(yieldedValue.
getType())) {
366 yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.
getType());
368 auto maybeBufferType =
370 if (
failed(maybeBufferType))
372 yieldedValueBufferType = *maybeBufferType;
376 if (*initArgBufferType == yieldedValueBufferType)
377 return yieldedValueBufferType;
382 auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
384 auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
385 assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
386 "expected same shape");
387 assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
388 "expected same memory space");
391 cast<RankedTensorType>(iterArg.
getType()),
392 yieldedRanked.getMemorySpace());
396 bool mayHaveZeroIterations(scf::ForOp forOp) {
399 if (!lb.has_value() || !ub.has_value())
406 struct ForOpInterface
407 :
public BufferizableOpInterface::ExternalModel<ForOpInterface,
411 auto forOp = cast<scf::ForOp>(op);
415 if (mayHaveZeroIterations(forOp))
420 return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
431 auto forOp = cast<scf::ForOp>(op);
432 OpResult opResult = forOp.getResultForOpOperand(opOperand);
434 return {{opResult, relation,
442 auto forOp = cast<scf::ForOp>(op);
443 OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
444 auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
446 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
447 bool equivalentYield = state.areEquivalentBufferizedValues(
466 auto bufferizableOp = cast<BufferizableOpInterface>(op);
467 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
470 if (!state.getOptions().enforceAliasingInvariants)
480 auto forOp = cast<scf::ForOp>(op);
482 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
492 forOp.getRegionIterArgs(), yieldOp.getResults(), state);
494 for (int64_t idx = 0;
495 idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
496 Value value = yieldOp.getResults()[idx];
497 if (!indices.contains(idx) || equivalentYields.contains(idx)) {
498 yieldValues.push_back(value);
503 true, state.getOptions());
506 yieldValues.push_back(*alloc);
510 yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
517 auto forOp = cast<scf::ForOp>(op);
519 assert(isa<TensorType>(value.
getType()) &&
"expected tensor type");
523 if (
auto bbArg = dyn_cast<BlockArgument>(value)) {
525 forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
528 resultNum = cast<OpResult>(value).getResultNumber();
533 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
534 Value yieldedValue = yieldOp.getOperand(resultNum);
535 BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
536 Value initArg = forOp.getInitArgs()[resultNum];
537 return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue,
543 auto forOp = cast<scf::ForOp>(op);
544 Block *oldLoopBody = &forOp.getLoopBody().
front();
552 getBuffers(rewriter, forOp.getIterOpOperands(),
options);
553 if (
failed(maybeInitArgs))
560 Value initArg = it.value();
561 Value result = forOp->getResult(it.index());
563 if (!isa<TensorType>(result.
getType())) {
564 castedInitArgs.push_back(initArg);
570 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
574 auto newForOp = rewriter.
create<scf::ForOp>(
575 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
576 forOp.getStep(), castedInitArgs);
577 newForOp->
setAttrs(forOp->getAttrs());
578 Block *loopBody = &newForOp.getLoopBody().
front();
584 getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
585 iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
588 rewriter.
mergeBlocks(oldLoopBody, loopBody, iterArgs);
610 auto forOp = cast<scf::ForOp>(op);
612 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
614 if (!isa<TensorType>(opResult.
getType()))
620 return yieldOp->emitError()
622 <<
" is not equivalent to the corresponding iter bbArg";
631 struct WhileOpInterface
632 :
public BufferizableOpInterface::ExternalModel<WhileOpInterface,
648 auto whileOp = cast<scf::WhileOp>(op);
658 OpResult opResult = whileOp->getResult(idx);
660 return {{opResult, relation,
670 auto whileOp = cast<scf::WhileOp>(op);
673 if (resultNumber >= whileOp.getBeforeArguments().size())
676 whileOp.getBeforeArguments()[resultNumber].getType())
679 auto conditionOp = whileOp.getConditionOp();
680 BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
681 Value conditionOperand = conditionOp.getArgs()[resultNumber];
682 bool equivCondition =
683 state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
685 auto yieldOp = whileOp.getYieldOp();
686 BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
687 Value yieldOperand = yieldOp.getOperand(resultNumber);
689 state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
708 auto bufferizableOp = cast<BufferizableOpInterface>(op);
709 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
712 if (!state.getOptions().enforceAliasingInvariants)
723 auto whileOp = cast<scf::WhileOp>(op);
724 auto conditionOp = whileOp.getConditionOp();
729 whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
731 whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
736 for (int64_t idx = 0;
737 idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
738 Value value = conditionOp.getArgs()[idx];
739 if (!isa<TensorType>(value.
getType()) ||
740 (equivalentYieldsAfter.contains(idx) &&
741 equivalentYieldsBefore.contains(idx))) {
742 beforeYieldValues.push_back(value);
747 true, state.getOptions());
750 beforeYieldValues.push_back(*alloc);
753 conditionOp.getArgsMutable().assign(beforeYieldValues);
761 auto whileOp = cast<scf::WhileOp>(op);
763 assert(whileOp.getBefore().getBlocks().size() == 1 &&
764 "regions with multiple blocks not supported");
765 Block *beforeBody = &whileOp.getBefore().
front();
766 assert(whileOp.getAfter().getBlocks().size() == 1 &&
767 "regions with multiple blocks not supported");
768 Block *afterBody = &whileOp.getAfter().
front();
774 getTensorIndices(whileOp.getAfterArguments());
778 getBuffers(rewriter, whileOp->getOpOperands(),
options);
779 if (
failed(maybeInitArgs))
786 Value initArg = it.value();
787 Value beforeArg = whileOp.getBeforeArguments()[it.index()];
789 if (!isa<TensorType>(beforeArg.
getType())) {
790 castedInitArgs.push_back(initArg);
796 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
801 llvm::map_range(whileOp.getAfterArguments(), [&](
BlockArgument bbArg) {
802 if (!isa<TensorType>(bbArg.getType()))
803 return bbArg.getType();
805 return llvm::cast<Type>(*bufferization::getBufferType(bbArg, options));
810 TypeRange argsTypesBefore(argsRangeBefore);
811 auto newWhileOp = rewriter.
create<scf::WhileOp>(
812 whileOp.getLoc(), argsTypesAfter, castedInitArgs);
819 Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
820 newWhileOp.getBefore().
addArguments(argsTypesBefore, bbArgLocsBefore);
821 Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
822 newWhileOp.getAfter().
addArguments(argsTypesAfter, bbArgLocsAfter);
829 rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
830 rewriter.
mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
837 rewriter, newWhileOp.getAfterArguments(), indicesAfter);
838 rewriter.
mergeBlocks(afterBody, newAfterBody, newAfterArgs);
849 auto whileOp = cast<scf::WhileOp>(op);
851 assert(isa<TensorType>(value.
getType()) &&
"expected tensor type");
854 if (
auto bbArg = dyn_cast<BlockArgument>(value)) {
855 if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
856 Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
857 auto yieldOp = whileOp.getYieldOp();
858 Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
859 return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue,
868 if (
auto opResult = dyn_cast<OpResult>(value)) {
870 }
else if (cast<BlockArgument>(value).getOwner()->getParent() ==
871 &whileOp.getAfter()) {
872 resultNum = cast<BlockArgument>(value).getArgNumber();
874 llvm_unreachable(
"invalid value");
876 Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
877 if (!isa<TensorType>(conditionYieldedVal.
getType())) {
879 return cast<BaseMemRefType>(conditionYieldedVal.
getType());
897 auto whileOp = cast<scf::WhileOp>(op);
903 auto conditionOp = whileOp.getConditionOp();
905 Block *block = conditionOp->getBlock();
906 if (!isa<TensorType>(it.value().getType()))
909 !state.areEquivalentBufferizedValues(it.value(),
911 return conditionOp->emitError()
912 <<
"Condition arg #" << it.index()
913 <<
" is not equivalent to the corresponding iter bbArg";
916 auto yieldOp = whileOp.getYieldOp();
918 Block *block = yieldOp->getBlock();
919 if (!isa<TensorType>(it.value().getType()))
922 !state.areEquivalentBufferizedValues(it.value(),
924 return yieldOp->emitError()
925 <<
"Yield operand #" << it.index()
926 <<
" is not equivalent to the corresponding iter bbArg";
935 struct YieldOpInterface
936 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
950 if (
auto ifOp = dyn_cast<scf::IfOp>(op->
getParentOp())) {
970 auto yieldOp = cast<scf::YieldOp>(op);
971 if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
972 yieldOp->getParentOp()))
973 return yieldOp->emitError(
"unsupported scf::YieldOp parent");
977 Value value = it.value();
978 if (isa<TensorType>(value.
getType())) {
982 Value buffer = *maybeBuffer;
984 if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
986 yieldOp->getParentOp()->getResult(it.index()),
options);
989 buffer = castBuffer(rewriter, buffer, *resultType);
990 }
else if (
auto whileOp =
991 dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
993 whileOp.getBeforeArguments()[it.index()],
options);
996 buffer = castBuffer(rewriter, buffer, *resultType);
998 newResults.push_back(buffer);
1000 newResults.push_back(value);
1004 replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
1010 bool mayHaveZeroIterations(scf::ForallOp forallOp) {
1011 for (
auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
1012 forallOp.getMixedUpperBound())) {
1015 if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
1025 struct ForallOpInterface
1026 :
public BufferizableOpInterface::ExternalModel<ForallOpInterface,
1030 auto forallOp = cast<ForallOp>(op);
1035 if (mayHaveZeroIterations(forallOp))
1040 return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
1051 auto forallOp = cast<ForallOp>(op);
1064 auto forallOp = cast<ForallOp>(op);
1065 int64_t rank = forallOp.getRank();
1069 for (
Value out : forallOp.getOutputs()) {
1073 buffers.push_back(*buffer);
1078 for (
const auto &it : llvm::zip(
1079 forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
1081 Value buffer = std::get<1>(it);
1082 Value bufferAsTensor =
1083 rewriter.
create<ToTensorOp>(forallOp.getLoc(), buffer);
1090 ForallOp newForallOp;
1091 newForallOp = rewriter.
create<ForallOp>(
1092 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1093 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1096 rewriter.
eraseOp(newForallOp.getBody()->getTerminator());
1100 replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
1101 newForallOp.getBody()->getArguments().end());
1102 replacementBbArgs.append(forallOp.getOutputs().size(),
Value());
1103 rewriter.
mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
1115 auto forallOp = cast<ForallOp>(op);
1117 if (
auto bbArg = dyn_cast<BlockArgument>(value))
1121 forallOp.getTiedOpOperand(bbArg)->get(),
options, fixedTypes);
1126 forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
options,
1131 auto forallOp = cast<ForallOp>(op);
1135 for (
auto [lb, ub, step] :
1136 llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1137 forallOp.getMixedStep())) {
1150 if (*lbConstant + *stepConstant < *ubConstant)
1158 struct InParallelOpInterface
1159 :
public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
1163 llvm_unreachable(
"op does not have any tensor OpOperands / OpResults");
1175 ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
1176 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1177 ForOp::attachInterface<ForOpInterface>(*ctx);
1178 IfOp::attachInterface<IfOpInterface>(*ctx);
1179 ForallOp::attachInterface<ForallOpInterface>(*ctx);
1180 InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
1181 WhileOp::attachInterface<WhileOpInterface>(*ctx);
1182 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static llvm::ManagedStatic< PassManagerOptions > options
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base class for generic analysis states.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getOpResults()
unsigned getNumResults()
Return the number of results held by this operation.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, bool escape, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specifies a fine-grained relationship between buffers to enable more analysis.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
Options for analysis-enabled bufferization.