22#include "llvm/ADT/SmallVectorExtras.h"
34static Value castBuffer(OpBuilder &
b, Value buffer, Type type) {
41 assert(isa<BaseMemRefType>(type) &&
"expected BaseMemRefType");
42 assert(isa<BaseMemRefType>(buffer.
getType()) &&
"expected BaseMemRefType");
46 assert(memref::CastOp::areCastCompatible(buffer.
getType(), type) &&
47 "scf.while op bufferization: cast incompatible");
48 return memref::CastOp::create(
b, buffer.
getLoc(), type, buffer).getResult();
/// Presumably returns whether `value` has no aliases defined outside of
/// `region` (used by the scf.for conflict resolution below) — NOTE(review):
/// this is a fragment; the `exceptions`, `alias`, and `aliasRegion` names
/// used below are declared on lines elided by the extraction, and the
/// original-file line numbers are fused into the text. Verify against the
/// full source.
55static bool doesNotAliasExternalValue(Value value, Region *region,
 57                                      const OneShotAnalysisState &state) {
  // Precondition: the region must have exactly one block.
 58  assert(region->
hasOneBlock() &&
"expected region with single block");
  // Aliases explicitly listed in `exceptions` are ignored.
 61    if (llvm::is_contained(exceptions, alias))
  // An OpResult alias whose region is not nested under `region` escapes it.
 66    if (isa<OpResult>(alias) && !region->
isAncestor(aliasRegion))
/// Bufferization external model for scf.condition: `bufferize` forwards the
/// condition value unchanged and replaces each tensor-like argument with its
/// buffer, cast to the buffer type of the corresponding scf.while "after"
/// block argument. NOTE(review): fragment — the bodies of the four query
/// methods, failure checks, else-branches, and closing braces were elided by
/// the extraction, and original-file line numbers are fused into the text.
73struct ConditionOpInterface
 74    :
public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
 76  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 77                              const AnalysisState &state)
const {
 81  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 82                               const AnalysisState &state)
const {
 86  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 87                                      const AnalysisState &state)
const {
 91  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
 92                            const AnalysisState &state)
const {
 99  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 100                          const BufferizationOptions &
options,
 101                          BufferizationState &state)
const {
 102    auto conditionOp = cast<scf::ConditionOp>(op);
 103    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
 105    SmallVector<Value> newArgs;
 106    for (
const auto &it : llvm::enumerate(conditionOp.getArgs())) {
 107      Value value = it.value();
 108      if (isa<TensorLikeType>(value.
getType())) {
 109        FailureOr<Value> maybeBuffer =
 110            getBuffer(rewriter, value,
options, state);
        // Target type: buffer type of the matching "after" block argument.
 113        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
 114            whileOp.getAfterArguments()[it.index()],
options, state);
 117        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
 118        newArgs.push_back(buffer);
        // Non-tensor arguments are forwarded unchanged.
 120        newArgs.push_back(value);
    // Rebuild scf.condition with the same condition and the new args.
 124    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
 125        rewriter, op, conditionOp.getCondition(), newArgs);
132static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
134 for (
Block &block : executeRegionOp.getRegion()) {
135 if (
auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
/// Bufferization external model for scf.execute_region (unstructured control
/// flow supported). `verifyAnalysis` rejects ops without a unique scf.yield;
/// `bufferize` rebuilds the op with the yielded (buffer) result types, moves
/// the body over, and wraps tensor results in bufferization.to_tensor.
/// NOTE(review): fragment — method bodies, failure checks, else-branches and
/// closing braces were elided by the extraction; original-file line numbers
/// are fused into the text.
146struct ExecuteRegionOpInterface
 147    :
public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
 148          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
 150  static bool supportsUnstructuredControlFlow() {
return true; }
 152  bool isWritable(Operation *op, Value value,
 153                  const AnalysisState &state)
const {
 157  LogicalResult verifyAnalysis(Operation *op,
 158                               const AnalysisState &state)
const {
 159    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // Only ops with exactly one scf.yield terminator are supported.
 161    if (!getUniqueYieldOp(executeRegionOp))
 162      return op->
emitOpError(
"op without unique scf.yield is not supported");
 166  AliasingOpOperandList
 167  getAliasingOpOperands(Operation *op, Value value,
 168                        const AnalysisState &state)
const {
    // Block arguments are handled by the unstructured-control-flow base.
 169    if (
auto bbArg = dyn_cast<BlockArgument>(value))
 170      return getAliasingBranchOpOperands(op, bbArg, state);
 176    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
 178    assert(it != op->
getOpResults().end() &&
"invalid value");
 179    size_t resultNum = std::distance(op->
getOpResults().begin(), it);
 180    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Each result is equivalent to the corresponding yield operand.
 184    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
 187  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 188                          const BufferizationOptions &
options,
 189                          BufferizationState &state)
const {
 190    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
 191    auto yieldOp = getUniqueYieldOp(executeRegionOp);
 192    TypeRange newResultTypes(yieldOp.getResults());
    // New op takes over the old body wholesale.
 195    auto newOp = scf::ExecuteRegionOp::create(
 196        rewriter, op->
getLoc(), newResultTypes, executeRegionOp.getNoInline());
 197    newOp.getRegion().takeBody(executeRegionOp.getRegion());
 200    for (
Block &block : newOp.getRegion())
 207    SmallVector<Value> newResults;
 208    for (
const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      // Tensor results are rematerialized via bufferization.to_tensor.
 209      if (isa<TensorType>(it.value())) {
 210        newResults.push_back(bufferization::ToTensorOp::create(
 211            rewriter, executeRegionOp.getLoc(), it.value(),
 212            newOp->getResult(it.index())));
 214        newResults.push_back(newOp->getResult(it.index()));
 219    rewriter.
replaceOp(executeRegionOp, newResults);
227 :
public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
228 AliasingOpOperandList
229 getAliasingOpOperands(Operation *op, Value value,
230 const AnalysisState &state)
const {
235 auto ifOp = cast<scf::IfOp>(op);
236 size_t resultNum = std::distance(op->
getOpResults().begin(),
238 OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
239 OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
240 return {{thenOperand, BufferRelation::Equivalent,
false},
241 {elseOperand, BufferRelation::Equivalent,
false}};
244 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
245 const BufferizationOptions &
options,
246 BufferizationState &state)
const {
247 OpBuilder::InsertionGuard g(rewriter);
248 auto ifOp = cast<scf::IfOp>(op);
251 SmallVector<Type> newTypes;
252 for (Value
result : ifOp.getResults()) {
253 if (!isa<TensorLikeType>(
result.getType())) {
254 newTypes.push_back(
result.getType());
257 auto bufferType = bufferization::getBufferType(
result,
options, state);
260 newTypes.push_back(*bufferType);
265 auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
270 rewriter.
mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
271 rewriter.
mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
274 replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
279 FailureOr<BufferLikeType>
281 const BufferizationState &state,
282 SmallVector<Value> &invocationStack)
const {
283 auto ifOp = cast<scf::IfOp>(op);
284 auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
285 auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
289 auto opResult = cast<OpResult>(value);
290 auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
291 auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
292 BufferLikeType thenBufferType, elseBufferType;
293 if (isa<BufferLikeType>(thenValue.getType())) {
295 thenBufferType = cast<BufferLikeType>(thenValue.getType());
297 auto maybeBufferType = bufferization::getBufferType(
298 thenValue,
options, state, invocationStack);
299 if (
failed(maybeBufferType))
301 thenBufferType = *maybeBufferType;
303 if (isa<BufferLikeType>(elseValue.getType())) {
305 elseBufferType = cast<BufferLikeType>(elseValue.getType());
307 auto maybeBufferType = bufferization::getBufferType(
308 elseValue,
options, state, invocationStack);
309 if (
failed(maybeBufferType))
311 elseBufferType = *maybeBufferType;
315 if (thenBufferType == elseBufferType)
316 return cast<BufferLikeType>(thenBufferType);
319 auto thenBaseMemRefType = dyn_cast<BaseMemRefType>(thenBufferType);
320 auto elseBaseMemRefType = dyn_cast<BaseMemRefType>(elseBufferType);
321 if (thenBaseMemRefType && elseBaseMemRefType &&
322 thenBaseMemRefType.getMemorySpace() !=
323 elseBaseMemRefType.getMemorySpace())
324 return op->
emitError(
"inconsistent memory space on then/else branches");
329 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
330 cast<TensorType>(opResult.getType()),
331 thenBaseMemRefType.getMemorySpace()));
/// Bufferization external model for scf.index_switch. Each result aliases
/// the matching yield operand of every case block and of the default block;
/// `bufferize` recreates the op with buffer result types; `getBufferType`
/// reconciles the buffer types yielded by all blocks.
/// NOTE(review): fragment — failure checks, else-branches and closing braces
/// were elided by the extraction; original-file line numbers are fused into
/// the text.
337struct IndexSwitchOpInterface
 338    :
public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
 339                                                  scf::IndexSwitchOp> {
 340  AliasingOpOperandList
 341  getAliasingOpOperands(Operation *op, Value value,
 342                        const AnalysisState &state)
const {
 345    auto switchOp = cast<scf::IndexSwitchOp>(op);
 346    int64_t resultNum = cast<OpResult>(value).getResultNumber();
 347    AliasingOpOperandList
result;
    // One alias per case block...
 348    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
 350          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
 351      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
 352                                        BufferRelation::Equivalent,
    // ...plus one for the default block.
 355    auto defaultYieldOp =
 356        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
 357    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
 358                                      BufferRelation::Equivalent,
 363  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 364                          const BufferizationOptions &
options,
 365                          BufferizationState &state)
const {
 366    OpBuilder::InsertionGuard g(rewriter);
 367    auto switchOp = cast<scf::IndexSwitchOp>(op);
    // New result types: non-tensors pass through, tensors become buffers.
 370    SmallVector<Type> newTypes;
 371    for (Value
result : switchOp.getResults()) {
 372      if (!isa<TensorType>(
result.getType())) {
 373        newTypes.push_back(
result.getType());
 376      auto bufferType = bufferization::getBufferType(
result,
options, state);
 379      newTypes.push_back(*bufferType);
 384    auto newSwitchOp = scf::IndexSwitchOp::create(
 385        rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
 386        switchOp.getCases(), switchOp.getCases().size());
    // Move case regions and the default region into the new op.
 389    for (
auto [src, dest] :
 390         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
 393                                newSwitchOp.getDefaultRegion(),
 394                                newSwitchOp.getDefaultRegion().begin());
 397    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
 402  FailureOr<BufferLikeType>
 404                const BufferizationState &state,
 405                SmallVector<Value> &invocationStack)
const {
 406    auto switchOp = cast<scf::IndexSwitchOp>(op);
 408    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    // Helper: buffer type of the value this block yields for `resultNum`.
 412    auto getYieldedBufferType = [&](
Block &
b) -> FailureOr<BaseMemRefType> {
 413      auto yieldOp = cast<scf::YieldOp>(
b.getTerminator());
 414      Value yieldedValue = yieldOp->getOperand(resultNum);
 415      if (
auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.
getType()))
 417      auto maybeBufferType = bufferization::getBufferType(
 418          yieldedValue,
options, state, invocationStack);
 419      return bufferization::detail::asMemRefType(maybeBufferType);
    // Start from the default block's type, then reconcile each case.
 423    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
 424    if (
failed(maybeBufferType))
 426    BaseMemRefType bufferType = *maybeBufferType;
 429    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
 430      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
 431      if (
failed(yieldedBufferType))
 435      if (bufferType == *yieldedBufferType)
      // Memory spaces must agree; layouts are relaxed to fully dynamic.
 439      if (bufferType.
getMemorySpace() != yieldedBufferType->getMemorySpace())
 440        return op->
emitError(
"inconsistent memory space on switch cases");
 445      bufferType = getMemRefTypeWithFullyDynamicLayout(
 449    return cast<BufferLikeType>(bufferType);
// NOTE(review): two loop-bufferization helper fragments whose signatures and
// declarations were partially elided by the extraction (original-file line
// numbers are fused into the text).
// First fragment: presumably the body of getTensorIndices(...) (see the call
// sites below) — collects the indices of tensor-like values into `result`,
// which is declared on an elided line.
 457  for (
const auto &it : llvm::enumerate(values))
 458    if (isa<TensorLikeType>(it.value().getType()))
 459      result.insert(it.index());
// Second fragment: presumably getEquivalentBuffers(...) — compares block
// arguments against yielded values and records which tensor-like pairs
// bufferize to equivalent buffers; the enclosing signature and the insert
// statement are on elided lines.
 467                        const AnalysisState &state) {
 468  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
 470  for (
unsigned int i = 0; i < minSize; ++i) {
    // Only tensor-like bbArg/yield pairs are considered.
 471    if (!isa<TensorLikeType>(bbArgs[i].
getType()) ||
 472        !isa<TensorLikeType>(yieldedValues[i].
getType()))
 474    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
482static FailureOr<SmallVector<Value>>
483getBuffers(RewriterBase &rewriter,
const MutableOperandRange &operands,
484 const BufferizationOptions &
options, BufferizationState &state) {
485 SmallVector<Value>
result;
486 for (OpOperand &opOperand : operands) {
487 if (isa<TensorLikeType>(opOperand.get().getType())) {
488 FailureOr<Value> resultBuffer =
489 getBuffer(rewriter, opOperand.get(),
options, state);
492 result.push_back(*resultBuffer);
494 result.push_back(opOperand.get());
/// Presumably the helper getBbArgReplacements(...) (see the call sites in
/// the for/while bufferize methods below): for each new block argument,
/// wraps buffer-typed arguments at `tensorIndices` in a
/// bufferization.to_tensor of the old argument's type so the old tensor body
/// can be merged in. NOTE(review): fragment — the parameter list, the push
/// into `result`, the else branch and the return were elided; original-file
/// line numbers are fused into the text.
503static SmallVector<Value>
 507  SmallVector<Value>
result;
 508  for (
const auto &it : llvm::enumerate(bbArgs)) {
 509    size_t idx = it.index();
 510    Value val = it.value();
    // Only arguments that were tensors in the old op get a to_tensor wrap.
 511    if (tensorIndices.contains(idx)) {
 513          bufferization::ToTensorOp::create(rewriter, val.
getLoc(),
 514                                            oldBbArgs[idx].getType(), val)
/// Compute the buffer type of a loop-region iter_arg by reconciling the
/// buffer types of its init value and its yielded value. Shared by the
/// scf.for and scf.while models below.
/// NOTE(review): fragment — failure-propagation lines, else-branch braces
/// and some statements were elided by the extraction; original-file line
/// numbers are fused into the text.
535static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
 536    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
 537    const BufferizationOptions &
options,
const BufferizationState &state,
 538    SmallVector<Value> &invocationStack) {
 540  auto initArgBufferType =
 541      bufferization::getBufferType(initArg,
options, state, invocationStack);
 542  if (
failed(initArgBufferType))
  // Recursion guard: if this iter_arg is already being computed twice on the
  // invocation stack, fall back to the init arg's buffer type.
 545  if (llvm::count(invocationStack, iterArg) >= 2) {
 556    return *initArgBufferType;
  // Buffer type of the yielded value (already bufferized values carry it).
 560  BufferLikeType yieldedValueBufferType;
 561  if (isa<BufferLikeType>(yieldedValue.
getType())) {
 563    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.
getType());
 567    auto maybeBufferType = bufferization::getBufferType(yieldedValue,
options,
 568                                                      state, invocationStack);
 569    if (
failed(maybeBufferType))
 571    yieldedValueBufferType = *maybeBufferType;
  // Fast path: init and yielded types already agree.
 575  if (*initArgBufferType == yieldedValueBufferType)
 576    return yieldedValueBufferType;
 581  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
 582  auto iterTensorType = cast<TensorType>(iterArg.
getType());
 583  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  // Memory spaces cannot be reconciled by a layout change.
 584  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
 586        "init_arg and yielded value bufferize to inconsistent memory spaces");
 588  if (
auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
 590        llvm::all_equal({yieldedRankedBufferType.getShape(),
 591                         cast<MemRefType>(initBufferType).getShape(),
 592                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
 593        "expected same shape");
  // Differing layouts: use a fully dynamic layout in the shared memory space.
 598  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
 599      iterTensorType, yieldedBufferType.getMemorySpace()));
/// Presumably returns "true" when the given scf.for loop may execute zero
/// times. NOTE(review): fragment — the computation of `lb`/`ub` (apparently
/// constant-evaluated loop bounds) and the return statements were elided by
/// the extraction; original-file line numbers are fused into the text.
603bool mayHaveZeroIterations(scf::ForOp forOp) {
  // Non-constant bounds: presumably treated conservatively (may be zero).
 606  if (!lb.has_value() || !ub.has_value())
614 :
public BufferizableOpInterface::ExternalModel<ForOpInterface,
616 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
617 const AnalysisState &state)
const {
618 auto forOp = cast<scf::ForOp>(op);
622 if (mayHaveZeroIterations(forOp))
627 return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
630 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
631 const AnalysisState &state)
const {
636 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
637 const AnalysisState &state)
const {
638 auto forOp = cast<scf::ForOp>(op);
639 OpResult opResult = forOp.getTiedLoopResult(&opOperand);
640 BufferRelation relation = bufferRelation(op, opResult, state);
641 return {{opResult, relation,
642 relation == BufferRelation::Equivalent}};
645 BufferRelation bufferRelation(Operation *op, OpResult opResult,
646 const AnalysisState &state)
const {
649 auto forOp = cast<scf::ForOp>(op);
650 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
651 bool equivalentYield = state.areEquivalentBufferizedValues(
652 bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
653 return equivalentYield ? BufferRelation::Equivalent
654 : BufferRelation::Unknown;
657 bool isWritable(Operation *op, Value value,
658 const AnalysisState &state)
const {
669 resolveConflicts(Operation *op, RewriterBase &rewriter,
670 const AnalysisState &analysisState,
671 const BufferizationState &bufferizationState)
const {
672 auto bufferizableOp = cast<BufferizableOpInterface>(op);
673 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(
674 rewriter, analysisState, bufferizationState)))
677 if (analysisState.getOptions().copyBeforeWrite)
685 auto forOp = cast<scf::ForOp>(op);
686 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
687 OpBuilder::InsertionGuard g(rewriter);
695 SmallVector<Value> yieldValues;
696 for (
const auto it : llvm::enumerate(yieldOp.getResults())) {
701 if (!
indices.contains(it.index()) ||
702 doesNotAliasExternalValue(
703 it.value(), &forOp.getRegion(),
704 forOp.getRegionIterArg(it.index()),
705 static_cast<const OneShotAnalysisState &
>(analysisState))) {
706 yieldValues.push_back(it.value());
709 FailureOr<Value> alloc = allocateTensorForShapedValue(
710 rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
714 yieldValues.push_back(*alloc);
718 yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
722 FailureOr<BufferLikeType>
724 const BufferizationState &state,
725 SmallVector<Value> &invocationStack)
const {
726 auto forOp = cast<scf::ForOp>(op);
728 assert(isa<TensorLikeType>(value.
getType()) &&
"expected tensor type");
730 if (
auto opResult = dyn_cast<OpResult>(value)) {
732 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
733 return bufferization::getBufferType(bbArg,
options, state,
738 BlockArgument bbArg = cast<BlockArgument>(value);
739 unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
742 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
743 Value yieldedValue = yieldOp.getOperand(resultNum);
744 BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
745 Value initArg = forOp.getInitArgs()[resultNum];
746 return computeLoopRegionIterArgBufferType(
747 op, iterArg, initArg, yieldedValue,
options, state, invocationStack);
750 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
751 const BufferizationOptions &
options,
752 BufferizationState &state)
const {
753 auto forOp = cast<scf::ForOp>(op);
754 Block *oldLoopBody = forOp.getBody();
761 FailureOr<SmallVector<Value>> maybeInitArgs =
762 getBuffers(rewriter, forOp.getInitArgsMutable(),
options, state);
763 if (
failed(maybeInitArgs))
765 SmallVector<Value> initArgs = *maybeInitArgs;
768 SmallVector<Value> castedInitArgs;
769 for (
const auto &it : llvm::enumerate(initArgs)) {
770 Value initArg = it.value();
771 Value
result = forOp->getResult(it.index());
773 if (!isa<TensorLikeType>(
result.getType())) {
774 castedInitArgs.push_back(initArg);
777 auto targetType = bufferization::getBufferType(
result,
options, state);
780 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
784 auto newForOp = scf::ForOp::create(
785 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
786 forOp.getStep(), castedInitArgs,
nullptr,
787 forOp.getUnsignedCmp());
788 newForOp->setAttrs(forOp->getAttrs());
789 Block *loopBody = newForOp.getBody();
794 SmallVector<Value> iterArgs =
795 getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
796 forOp.getRegionIterArgs(),
indices);
797 iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
800 rewriter.
mergeBlocks(oldLoopBody, loopBody, iterArgs);
803 replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
815 LogicalResult verifyAnalysis(Operation *op,
816 const AnalysisState &state)
const {
818 static_cast<const OneShotBufferizationOptions &
>(state.getOptions());
819 if (
options.allowReturnAllocsFromLoops)
822 auto forOp = cast<scf::ForOp>(op);
823 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
825 if (!isa<TensorLikeType>(opResult.
getType()))
829 if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
830 return yieldOp->emitError()
832 <<
" is not equivalent to the corresponding iter bbArg";
/// Bufferization external model for scf.while. Results are equivalent to
/// their init operand only when both the scf.condition argument and the
/// scf.yield operand are equivalent to the corresponding block arguments;
/// `resolveConflicts` yields fresh allocations for non-equivalent condition
/// args; `bufferize` rebuilds the loop over buffers.
/// NOTE(review): fragment — failure checks, else-branches and closing braces
/// were elided by the extraction; original-file line numbers are fused into
/// the text.
841struct WhileOpInterface
 842    :
public BufferizableOpInterface::ExternalModel<WhileOpInterface,
 844  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 845                              const AnalysisState &state)
const {
 850  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 851                               const AnalysisState &state)
const {
 856  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 857                                      const AnalysisState &state)
const {
 858    auto whileOp = cast<scf::WhileOp>(op);
 868    OpResult opResult = whileOp->getResult(idx);
 869    BufferRelation relation = bufferRelation(op, opResult, state);
 870    return {{opResult, relation,
 871             relation == BufferRelation::Equivalent}};
 874  BufferRelation bufferRelation(Operation *op, OpResult opResult,
 875                                const AnalysisState &state)
const {
 880    auto whileOp = cast<scf::WhileOp>(op);
    // Results without a matching "before" argument cannot be classified.
 883    if (resultNumber >= whileOp.getBeforeArguments().size())
 884      return BufferRelation::Unknown;
 886        whileOp.getBeforeArguments()[resultNumber].getType())
 887      return BufferRelation::Unknown;
    // Check equivalence through the condition op...
 889    auto conditionOp = whileOp.getConditionOp();
 890    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
 891    Value conditionOperand = conditionOp.getArgs()[resultNumber];
 892    bool equivCondition =
 893        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
    // ...and through the yield op.
 895    auto yieldOp = whileOp.getYieldOp();
 896    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
 897    Value yieldOperand = yieldOp.getOperand(resultNumber);
 899        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
 901    return equivCondition && equivYield ? BufferRelation::Equivalent
 902                                        : BufferRelation::Unknown;
 905  bool isWritable(Operation *op, Value value,
 906                  const AnalysisState &state)
const {
 917  resolveConflicts(Operation *op, RewriterBase &rewriter,
 918                   const AnalysisState &analysisState,
 919                   const BufferizationState &bufferizationState)
const {
    // First run the default operand-conflict resolution.
 920    auto bufferizableOp = cast<BufferizableOpInterface>(op);
 921    if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(
 922            rewriter, analysisState, bufferizationState)))
 925    if (analysisState.getOptions().copyBeforeWrite)
 935    OpBuilder::InsertionGuard g(rewriter);
 936    auto whileOp = cast<scf::WhileOp>(op);
 937    auto conditionOp = whileOp.getConditionOp();
    // Determine which before/after positions are already equivalent.
 942        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
 944        getEquivalentBuffers(whileOp.getAfterArguments(),
 945                             whileOp.getYieldOp().getResults(), analysisState);
 949    SmallVector<Value> beforeYieldValues;
 950    for (int64_t idx = 0;
 951         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
 952      Value value = conditionOp.getArgs()[idx];
      // Non-tensors and fully equivalent positions are kept as-is.
 953      if (!isa<TensorLikeType>(value.
getType()) ||
 954          (equivalentYieldsAfter.contains(idx) &&
 955           equivalentYieldsBefore.contains(idx))) {
 956        beforeYieldValues.push_back(value);
      // Otherwise feed the condition a fresh allocation.
 959      FailureOr<Value> alloc = allocateTensorForShapedValue(
 960          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
 964      beforeYieldValues.push_back(*alloc);
 967      conditionOp.getArgsMutable().assign(beforeYieldValues);
 973  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 974                          const BufferizationOptions &
options,
 975                          BufferizationState &state)
const {
 976    auto whileOp = cast<scf::WhileOp>(op);
 982        getTensorIndices(whileOp.getAfterArguments());
    // Bufferize the init values (tensors -> buffers).
 985    FailureOr<SmallVector<Value>> maybeInitArgs =
 986        getBuffers(rewriter, whileOp.getInitsMutable(),
options, state);
 987    if (
failed(maybeInitArgs))
 989    SmallVector<Value> initArgs = *maybeInitArgs;
    // Cast each buffer to the buffer type of the tied "before" argument.
 992    SmallVector<Value> castedInitArgs;
 993    for (
const auto &it : llvm::enumerate(initArgs)) {
 994      Value initArg = it.value();
 995      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
 997      if (!isa<TensorLikeType>(beforeArg.
getType())) {
 998        castedInitArgs.push_back(initArg);
 1001      auto targetType = bufferization::getBufferType(beforeArg,
options, state);
 1004      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    // Compute the new "after" argument types (tensors become buffer types).
 1008    SmallVector<Type> argsTypesAfter = llvm::map_to_vector(
 1009        whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
 1010          if (!isa<TensorLikeType>(bbArg.getType()))
 1011            return bbArg.getType();
 1013          return llvm::cast<Type>(
 1014              *bufferization::getBufferType(bbArg, options, state));
 1019    TypeRange argsTypesBefore(argsRangeBefore);
 1020    auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
 1021                                           argsTypesAfter, castedInitArgs);
    // Materialize empty before/after blocks with the new argument types.
 1024    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
 1026    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
 1028    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
 1029    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
 1030    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
 1031    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
    // Merge the old before-body with to_tensor-wrapped replacements...
 1037    SmallVector<Value> newBeforeArgs =
 1038        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
 1039                             whileOp.getBeforeArguments(), indicesBefore);
 1040    rewriter.
mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
    // ...and the old after-body likewise.
 1046    SmallVector<Value> newAfterArgs =
 1047        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
 1048                             whileOp.getAfterArguments(), indicesAfter);
 1049    rewriter.
mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
 1052    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
 1057  FailureOr<BufferLikeType>
 1059                const BufferizationState &state,
 1060                SmallVector<Value> &invocationStack)
const {
 1061    auto whileOp = cast<scf::WhileOp>(op);
 1063    assert(isa<TensorLikeType>(value.
getType()) &&
"expected tensor type");
    // "Before" block arguments behave like loop iter_args.
 1066    if (
auto bbArg = dyn_cast<BlockArgument>(value)) {
 1068      Value initArg = whileOp.getInits()[bbArg.
getArgNumber()];
 1069      auto yieldOp = whileOp.getYieldOp();
 1070      Value yieldedValue = yieldOp.getOperand(bbArg.
getArgNumber());
 1071      return computeLoopRegionIterArgBufferType(
 1072          op, bbArg, initArg, yieldedValue,
options, state, invocationStack);
    // Results and "after" block arguments take the type of the value passed
    // to scf.condition at the same position.
 1080    if (
auto opResult = dyn_cast<OpResult>(value)) {
 1082    }
else if (cast<BlockArgument>(value).getOwner()->getParent() ==
 1083               &whileOp.getAfter()) {
 1084      resultNum = cast<BlockArgument>(value).getArgNumber();
 1086      llvm_unreachable(
"invalid value");
 1088    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
 1089    if (!isa<TensorLikeType>(conditionYieldedVal.
getType())) {
 1091      return cast<BufferLikeType>(conditionYieldedVal.
getType());
 1093    return bufferization::getBufferType(conditionYieldedVal,
options, state,
 1107  LogicalResult verifyAnalysis(Operation *op,
 1108                               const AnalysisState &state)
const {
 1109    auto whileOp = cast<scf::WhileOp>(op);
 1111        static_cast<const OneShotBufferizationOptions &
>(state.getOptions());
    // Non-equivalent yields are only allowed when opted in.
 1112    if (
options.allowReturnAllocsFromLoops)
    // Every tensor condition arg must be equivalent to its before-bbArg...
 1115    auto conditionOp = whileOp.getConditionOp();
 1116    for (
const auto &it : llvm::enumerate(conditionOp.getArgs())) {
 1117      Block *block = conditionOp->getBlock();
 1118      if (!isa<TensorLikeType>(it.value().getType()))
 1121          !state.areEquivalentBufferizedValues(it.value(),
 1123        return conditionOp->emitError()
 1124               <<
"Condition arg #" << it.index()
 1125               <<
" is not equivalent to the corresponding iter bbArg";
    // ...and every tensor yield operand to its after-bbArg.
 1128    auto yieldOp = whileOp.getYieldOp();
 1129    for (
const auto &it : llvm::enumerate(yieldOp.getResults())) {
 1130      Block *block = yieldOp->getBlock();
 1131      if (!isa<TensorLikeType>(it.value().getType()))
 1134          !state.areEquivalentBufferizedValues(it.value(),
 1136        return yieldOp->emitError()
 1137               <<
"Yield operand #" << it.index()
 1138               <<
" is not equivalent to the corresponding iter bbArg";
/// Bufferization external model for scf.yield. Only yields nested in the
/// listed SCF parents are supported; `bufferize` replaces tensor-like
/// operands by buffers, cast to the buffer type of the parent op's result
/// (for/if/index_switch) or of the scf.while "before" argument.
/// NOTE(review): fragment — query-method bodies, failure checks and closing
/// braces were elided by the extraction; original-file line numbers are
/// fused into the text.
1147struct YieldOpInterface
 1148    :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
 1150  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 1151                              const AnalysisState &state)
const {
 1155  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 1156                               const AnalysisState &state)
const {
 1160  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 1161                                      const AnalysisState &state)
const {
    // Yields into scf.if are equivalent but non-definite (two branches).
 1162    if (
auto ifOp = dyn_cast<scf::IfOp>(op->
getParentOp())) {
 1164                BufferRelation::Equivalent,
false}};
 1168             BufferRelation::Equivalent}};
 1172  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
 1173                            const AnalysisState &state)
const {
 1180  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 1181                          const BufferizationOptions &
options,
 1182                          BufferizationState &state)
const {
 1183    auto yieldOp = cast<scf::YieldOp>(op);
 1184    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
 1185             scf::WhileOp>(yieldOp->getParentOp()))
 1186      return yieldOp->emitError(
"unsupported scf::YieldOp parent");
 1188    SmallVector<Value> newResults;
 1189    for (
const auto &it : llvm::enumerate(yieldOp.getResults())) {
 1190      Value value = it.value();
 1191      if (isa<TensorLikeType>(value.
getType())) {
 1192        FailureOr<Value> maybeBuffer =
 1193            getBuffer(rewriter, value,
options, state);
 1196        Value buffer = *maybeBuffer;
        // Cast to the parent result's buffer type...
 1198        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
 1199                yieldOp->getParentOp())) {
 1200          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
 1201              yieldOp->getParentOp()->getResult(it.index()),
options, state);
 1204          buffer = castBuffer(rewriter, buffer, *resultType);
        // ...or to the scf.while "before" argument's buffer type.
 1205        }
else if (
auto whileOp =
 1206                     dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
 1207          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
 1208              whileOp.getBeforeArguments()[it.index()],
options, state);
 1211          buffer = castBuffer(rewriter, buffer, *resultType);
 1213        newResults.push_back(buffer);
        // Non-tensor operands are forwarded unchanged.
 1215        newResults.push_back(value);
 1219    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
/// Bufferization external model for scf.forall. Each shared output is
/// equivalent to its tied result; `bufferize` obtains a buffer per output,
/// rematerializes it as a tensor for uses inside the body, rebuilds the op
/// without outputs and replaces the op with the buffers directly.
/// NOTE(review): fragment — query-method bodies, failure checks, the tail of
/// the repetitive-region helper and closing braces were elided by the
/// extraction; original-file line numbers are fused into the text.
1228struct ForallOpInterface
 1229    :
public BufferizableOpInterface::ExternalModel<ForallOpInterface,
 1231  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 1232                              const AnalysisState &state)
const {
 1240  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 1241                               const AnalysisState &state)
const {
 1246  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 1247                                      const AnalysisState &state)
const {
 1248    auto forallOp = cast<ForallOp>(op);
 1250        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
 1253  bool isWritable(Operation *op, Value value,
 1254                  const AnalysisState &state)
const {
 1258  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 1259                          const BufferizationOptions &
options,
 1260                          BufferizationState &state)
const {
 1261    OpBuilder::InsertionGuard guard(rewriter);
 1262    auto forallOp = cast<ForallOp>(op);
 1263    int64_t rank = forallOp.getRank();
    // Get a buffer for every shared output.
 1266    SmallVector<Value> buffers;
 1267    for (Value out : forallOp.getOutputs()) {
 1268      FailureOr<Value> buffer = getBuffer(rewriter, out,
options, state);
 1271      buffers.push_back(*buffer);
    // Rematerialize each output buffer as a tensor for body uses (the first
    // `rank` block arguments are the induction variables).
 1276    for (
const auto &it : llvm::zip(
 1277             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
 1278      BlockArgument bbArg = std::get<0>(it);
 1279      Value buffer = std::get<1>(it);
 1280      Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
    // Rebuild the forall without shared outputs.
 1288    ForallOp newForallOp;
 1289    newForallOp = ForallOp::create(
 1290        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
 1291        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
 1297    rewriter.
eraseOp(newForallOp.getBody()->getTerminator());
    // Old output bbArgs have no replacement (empty Value placeholders).
 1300    SmallVector<Value> replacementBbArgs;
 1301    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
 1302                             newForallOp.getBody()->getArguments().end());
 1303    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
 1304    rewriter.
mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
 1308    replaceOpWithBufferizedValues(rewriter, op, buffers);
 1313  FailureOr<BufferLikeType>
 1315                const BufferizationState &state,
 1316                SmallVector<Value> &invocationStack)
const {
 1317    auto forallOp = cast<ForallOp>(op);
    // Output bbArgs take the buffer type of their tied operand...
 1319    if (
auto bbArg = dyn_cast<BlockArgument>(value))
 1322      return bufferization::getBufferType(
 1323          forallOp.getTiedOpOperand(bbArg)->get(),
options, state,
    // ...and results take the buffer type of the matching output.
 1328    return bufferization::getBufferType(
 1329        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
options,
 1330        state, invocationStack);
    // Fragment of the repetitive-region check: presumably decides whether
    // the loop runs more than once by inspecting constant bounds/steps.
 1334    auto forallOp = cast<ForallOp>(op);
 1338    for (
auto [lb, ub, step] :
 1339         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
 1340                   forallOp.getMixedStep())) {
 1353      if (*lbConstant + *stepConstant < *ubConstant)
 1359  bool isParallelRegion(Operation *op,
unsigned index)
const {
/// Bufferization external model for scf.in_parallel. The op carries no
/// tensor operands or results of its own, so `bufferize` must never be
/// called on it.
/// NOTE(review): fragment — the template's second argument and closing
/// braces were elided by the extraction; original-file line numbers are
/// fused into the text.
1365struct InParallelOpInterface
 1366    :
public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
 1368  LogicalResult bufferize(Operation *op, RewriterBase &
b,
 1369                          const BufferizationOptions &
options,
 1370                          BufferizationState &state)
const {
 1371    llvm_unreachable(
"op does not have any tensor OpOperands / OpResults");
1383 ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
1384 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1385 ForOp::attachInterface<ForOpInterface>(*ctx);
1386 IfOp::attachInterface<IfOpInterface>(*ctx);
1387 IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
1388 ForallOp::attachInterface<ForallOpInterface>(*ctx);
1389 InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
1390 WhileOp::attachInterface<WhileOpInterface>(*ctx);
1391 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static llvm::ManagedStatic< PassManagerOptions > options
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
static Operation * getOwnerOfValue(Value value)
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
bool isProperAncestor(Region *other)
Return true if this region is a proper ancestor of the other region.
bool hasOneBlock()
Return true if this region has exactly one block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match; the original op is erased.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
void applyOnAliases(Value v, function_ref< void(Value)> fun) const
Apply fun to all aliases of v.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet