// Casts `buffer` to the given memref `type` by emitting a memref.cast.
// Both the current type of `buffer` and the target `type` must be
// BaseMemRefTypes and must be cast-compatible (both asserted below).
// NOTE(review): this fragment is truncated by extraction — the original
// early-return for an already-matching type and the closing brace are
// not visible here.
33static Value castBuffer(OpBuilder &
b, Value buffer, Type type) {
 34 assert(isa<BaseMemRefType>(type) &&
"expected BaseMemRefType");
 35 assert(isa<BaseMemRefType>(buffer.
getType()) &&
"expected BaseMemRefType");
// Incompatible casts would indicate a bufferization bug (e.g. a layout
// map mismatch on an scf.while iter_arg), hence the hard assert.
 42 assert(memref::CastOp::areCastCompatible(buffer.
getType(), type) &&
 43 "scf.while op bufferization: cast incompatible");
 44 return memref::CastOp::create(
b, buffer.
getLoc(), type, buffer).getResult();
// Checks whether `value` aliases nothing outside of `region` (fragment).
// NOTE(review): truncated by extraction — the declarations that bind
// `exceptions`, `alias` and `aliasRegion` (presumably an alias walk over
// `state`) are elided, as are the loop body and return statements.
// The region is required to have a single block (asserted).
51static bool doesNotAliasExternalValue(Value value, Region *region,
 53 const OneShotAnalysisState &state) {
 54 assert(region->
hasOneBlock() &&
"expected region with single block");
// Aliases explicitly whitelisted in `exceptions` are ignored.
 57 if (llvm::is_contained(exceptions, alias))
// An OpResult alias defined outside `region` means external aliasing.
 62 if (isa<OpResult>(alias) && !region->
isAncestor(aliasRegion))
// Bufferization model for scf.condition (the terminator of the "before"
// region of scf.while). NOTE(review): fragment truncated by extraction —
// the bodies of the query methods and several statements/braces of
// bufferize() are elided.
69struct ConditionOpInterface
 70 :
public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
 72 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 73 const AnalysisState &state)
const {
 77 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 78 const AnalysisState &state)
const {
 82 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 83 const AnalysisState &state)
const {
 87 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
 88 const AnalysisState &state)
const {
// Rewrites the scf.condition op: each tensor argument is replaced by its
// buffer, cast to the buffer type of the corresponding "after"-region
// block argument of the parent scf.while.
 95 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 96 const BufferizationOptions &
options,
 97 BufferizationState &state)
const {
 98 auto conditionOp = cast<scf::ConditionOp>(op);
 99 auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
 101 SmallVector<Value> newArgs;
 102 for (
const auto &it : llvm::enumerate(conditionOp.getArgs())) {
 103 Value value = it.value();
 104 if (isa<TensorType>(value.
getType())) {
 105 FailureOr<Value> maybeBuffer =
 106 getBuffer(rewriter, value,
options, state);
// The target type is dictated by the matching after-region bbArg.
 109 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
 110 whileOp.getAfterArguments()[it.index()],
options, state);
 113 Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
 114 newArgs.push_back(buffer);
// Non-tensor operands are passed through unchanged.
 116 newArgs.push_back(value);
 120 replaceOpWithNewBufferizedOp<scf::ConditionOp>(
 121 rewriter, op, conditionOp.getCondition(), newArgs);
// Returns the unique scf.yield terminator of the given scf.execute_region,
// scanning every block of its region. NOTE(review): fragment — the
// uniqueness bookkeeping and the return statements are elided by
// extraction; presumably a null op is returned when the yield is not
// unique (see its use in verifyAnalysis below).
128static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
 130 for (
Block &block : executeRegionOp.getRegion()) {
 131 if (
auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
// Bufferization model for scf.execute_region. Supports unstructured
// control flow inside the region; only ops with a unique scf.yield are
// supported. NOTE(review): fragment truncated by extraction — several
// statements, return paths and closing braces are elided.
142struct ExecuteRegionOpInterface
 143 :
public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
 144 ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
 146 static bool supportsUnstructuredControlFlow() {
return true; }
 148 bool isWritable(Operation *op, Value value,
 149 const AnalysisState &state)
const {
// Analysis is only valid when the op has a unique scf.yield terminator.
 153 LogicalResult verifyAnalysis(Operation *op,
 154 const AnalysisState &state)
const {
 155 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
 157 if (!getUniqueYieldOp(executeRegionOp))
 158 return op->
emitOpError(
"op without unique scf.yield is not supported");
// Each OpResult aliases the operand of the unique yield at the same
// position; block arguments defer to the branch-op operand lookup.
 162 AliasingOpOperandList
 163 getAliasingOpOperands(Operation *op, Value value,
 164 const AnalysisState &state)
const {
 165 if (
auto bbArg = dyn_cast<BlockArgument>(value))
 166 return getAliasingBranchOpOperands(op, bbArg, state);
 172 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
// NOTE(review): the lookup producing `it` is elided by extraction.
 174 assert(it != op->
getOpResults().end() &&
"invalid value");
 175 size_t resultNum = std::distance(op->
getOpResults().begin(), it);
 176 auto yieldOp = getUniqueYieldOp(executeRegionOp);
 180 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
// Rebuilds the op with the yielded (buffer) types and wraps tensor
// results back into to_tensor ops for the replacement values.
 183 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 184 const BufferizationOptions &
options,
 185 BufferizationState &state)
const {
 186 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
 187 auto yieldOp = getUniqueYieldOp(executeRegionOp);
 188 TypeRange newResultTypes(yieldOp.getResults());
 191 auto newOp = scf::ExecuteRegionOp::create(
 192 rewriter, op->
getLoc(), newResultTypes, executeRegionOp.getNoInline());
 193 newOp.getRegion().takeBody(executeRegionOp.getRegion());
 196 for (
Block &block : newOp.getRegion())
 203 SmallVector<Value> newResults;
 204 for (
const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
 205 if (isa<TensorType>(it.value())) {
// Tensor results: rewrap the buffer result as a tensor.
 206 newResults.push_back(bufferization::ToTensorOp::create(
 207 rewriter, executeRegionOp.getLoc(), it.value(),
 208 newOp->getResult(it.index())));
 210 newResults.push_back(newOp->getResult(it.index()));
 215 rewriter.
replaceOp(executeRegionOp, newResults);
// Bufferization model for scf.if. NOTE(review): the `struct IfOpInterface`
// header line and numerous interior statements are elided by extraction;
// this fragment begins at the base-class clause.
223 :
public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
// Each scf.if result aliases the yield operands at the same position in
// both the then- and else-branch (neither is a definite "last write").
 224 AliasingOpOperandList
 225 getAliasingOpOperands(Operation *op, Value value,
 226 const AnalysisState &state)
const {
 231 auto ifOp = cast<scf::IfOp>(op);
 232 size_t resultNum = std::distance(op->
getOpResults().begin(),
 234 OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
 235 OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
 236 return {{thenOperand, BufferRelation::Equivalent,
false},
 237 {elseOperand, BufferRelation::Equivalent,
false}};
// Rebuilds the scf.if with buffer result types and moves both branch
// blocks into the new op.
 240 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 241 const BufferizationOptions &
options,
 242 BufferizationState &state)
const {
 243 OpBuilder::InsertionGuard g(rewriter);
 244 auto ifOp = cast<scf::IfOp>(op);
 247 SmallVector<Type> newTypes;
 248 for (Value
result : ifOp.getResults()) {
 249 if (!isa<TensorType>(
result.getType())) {
// Non-tensor results keep their type.
 250 newTypes.push_back(
result.getType());
 253 auto bufferType = bufferization::getBufferType(
result,
options, state);
 256 newTypes.push_back(*bufferType);
 261 auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
 266 rewriter.
mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
 267 rewriter.
mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
 270 replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
// Computes the buffer type of an scf.if result by reconciling the buffer
// types yielded by the two branches: identical types are used as-is,
// differing memory spaces are an error, and otherwise a fully dynamic
// layout is used as the common denominator.
 275 FailureOr<BufferLikeType>
 277 const BufferizationState &state,
 278 SmallVector<Value> &invocationStack)
const {
 279 auto ifOp = cast<scf::IfOp>(op);
 280 auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
 281 auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
 285 auto opResult = cast<OpResult>(value);
 286 auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
 287 auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
 288 BaseMemRefType thenBufferType, elseBufferType;
 289 if (isa<BaseMemRefType>(thenValue.getType())) {
// Already bufferized (e.g. a prior rewrite) — use the type directly.
 291 thenBufferType = cast<BaseMemRefType>(thenValue.getType());
 293 auto maybeBufferType =
 294 bufferization::detail::asMemRefType(bufferization::getBufferType(
 295 thenValue,
options, state, invocationStack));
 296 if (
failed(maybeBufferType))
 298 thenBufferType = *maybeBufferType;
 300 if (isa<BaseMemRefType>(elseValue.getType())) {
 302 elseBufferType = cast<BaseMemRefType>(elseValue.getType());
 304 auto maybeBufferType =
 305 bufferization::detail::asMemRefType(bufferization::getBufferType(
 306 elseValue,
options, state, invocationStack));
 307 if (
failed(maybeBufferType))
 309 elseBufferType = *maybeBufferType;
 313 if (thenBufferType == elseBufferType)
 314 return cast<BufferLikeType>(thenBufferType);
// Layouts can be unified dynamically; memory spaces cannot.
 318 return op->
emitError(
"inconsistent memory space on then/else branches");
 321 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
 322 cast<TensorType>(opResult.getType()), thenBufferType.
getMemorySpace()));
// Bufferization model for scf.index_switch. Mirrors the scf.if model but
// over N case regions plus a default region. NOTE(review): fragment
// truncated by extraction — several statements, return paths and closing
// braces are elided.
328struct IndexSwitchOpInterface
 329 :
public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
 330 scf::IndexSwitchOp> {
// A result aliases the yield operand at the same position in every case
// block and in the default block.
 331 AliasingOpOperandList
 332 getAliasingOpOperands(Operation *op, Value value,
 333 const AnalysisState &state)
const {
 336 auto switchOp = cast<scf::IndexSwitchOp>(op);
 337 int64_t resultNum = cast<OpResult>(value).getResultNumber();
 338 AliasingOpOperandList
result;
 339 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
 341 cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
 342 result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
 343 BufferRelation::Equivalent,
 346 auto defaultYieldOp =
 347 cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
 348 result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
 349 BufferRelation::Equivalent,
// Rebuilds the op with buffer result types and moves all regions over.
 354 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 355 const BufferizationOptions &
options,
 356 BufferizationState &state)
const {
 357 OpBuilder::InsertionGuard g(rewriter);
 358 auto switchOp = cast<scf::IndexSwitchOp>(op);
 361 SmallVector<Type> newTypes;
 362 for (Value
result : switchOp.getResults()) {
 363 if (!isa<TensorType>(
result.getType())) {
 364 newTypes.push_back(
result.getType());
 367 auto bufferType = bufferization::getBufferType(
result,
options, state);
 370 newTypes.push_back(*bufferType);
 375 auto newSwitchOp = scf::IndexSwitchOp::create(
 376 rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
 377 switchOp.getCases(), switchOp.getCases().size());
 380 for (
auto [src, dest] :
 381 llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
 384 newSwitchOp.getDefaultRegion(),
 385 newSwitchOp.getDefaultRegion().begin());
 388 replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
// Computes the buffer type of a result by reconciling the buffer types
// yielded by the default block and every case block: equal types pass
// through, differing memory spaces are an error, otherwise fall back to
// a fully dynamic layout.
 393 FailureOr<BufferLikeType>
 395 const BufferizationState &state,
 396 SmallVector<Value> &invocationStack)
const {
 397 auto switchOp = cast<scf::IndexSwitchOp>(op);
 399 int64_t resultNum = cast<OpResult>(value).getResultNumber();
// Helper: buffer type of the value yielded by a given block.
 402 auto getYieldedBufferType = [&](
Block &
b) -> FailureOr<BaseMemRefType> {
 403 auto yieldOp = cast<scf::YieldOp>(
b.getTerminator());
 404 Value yieldedValue = yieldOp->getOperand(resultNum);
 405 if (
auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.
getType()))
 407 auto maybeBufferType = bufferization::getBufferType(
 408 yieldedValue,
options, state, invocationStack);
 409 return bufferization::detail::asMemRefType(maybeBufferType);
 413 auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
 414 if (
failed(maybeBufferType))
 416 BaseMemRefType bufferType = *maybeBufferType;
 419 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
 420 auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
 421 if (
failed(yieldedBufferType))
 425 if (bufferType == *yieldedBufferType)
 429 if (bufferType.
getMemorySpace() != yieldedBufferType->getMemorySpace())
 430 return op->
emitError(
"inconsistent memory space on switch cases");
 433 bufferType = getMemRefTypeWithFullyDynamicLayout(
 437 return cast<BufferLikeType>(bufferType);
// Fragment: collects into `result` the indices of all tensor-typed
// entries of `values`. NOTE(review): the enclosing function signature
// (presumably a getTensorIndices-style helper returning an index set)
// is elided by extraction — confirm against the full file.
 445 for (
const auto &it : llvm::enumerate(values))
 446 if (isa<TensorType>(it.value().getType()))
 447 result.insert(it.index());
// Fragment: for each position with tensor-typed bbArg AND tensor-typed
// yielded value, checks whether the two bufferize to equivalent buffers.
// NOTE(review): the function header (name and first parameters) and the
// loop-body statements that record the result are elided by extraction.
 455 const AnalysisState &state) {
 456 unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
 458 for (
unsigned int i = 0; i < minSize; ++i) {
// Skip positions where either side is not a tensor.
 459 if (!isa<TensorType>(bbArgs[i].
getType()) ||
 460 !isa<TensorType>(yieldedValues[i].
getType()))
 462 if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
// Returns the buffers of the given op operands: tensor operands are
// replaced by their bufferized value (via getBuffer), non-tensor operands
// are passed through unchanged. NOTE(review): fragment — failure
// propagation after getBuffer, the final return and closing braces are
// elided by extraction.
470static FailureOr<SmallVector<Value>>
471getBuffers(RewriterBase &rewriter,
const MutableOperandRange &operands,
 472 const BufferizationOptions &
options, BufferizationState &state) {
 473 SmallVector<Value>
result;
 474 for (OpOperand &opOperand : operands) {
 475 if (isa<TensorType>(opOperand.get().getType())) {
 476 FailureOr<Value> resultBuffer =
 477 getBuffer(rewriter, opOperand.get(),
options, state);
 480 result.push_back(*resultBuffer);
 482 result.push_back(opOperand.get());
// Computes replacement values for loop block arguments after
// bufferization: for each index in `tensorIndices`, the new (buffer)
// bbArg is wrapped in a to_tensor op with the old bbArg's tensor type;
// other bbArgs pass through. NOTE(review): fragment — the parameter list
// (rewriter, bbArgs, oldBbArgs, tensorIndices), the else-branch and the
// return are elided by extraction.
491static SmallVector<Value>
 495 SmallVector<Value>
result;
 496 for (
const auto &it : llvm::enumerate(bbArgs)) {
 497 size_t idx = it.index();
 498 Value val = it.value();
 499 if (tensorIndices.contains(idx)) {
 501 bufferization::ToTensorOp::create(rewriter, val.
getLoc(),
 502 oldBbArgs[idx].getType(), val)
// Computes the buffer type of a loop iter_arg (scf.for / scf.while) by
// reconciling the buffer types of the init operand and the yielded value.
// Fixpoint strategy: if the iter_arg already appears >= 2 times on the
// invocation stack, break the recursion by using the init-arg type; if
// init and yielded types agree, use that; otherwise memory spaces must
// match and a fully dynamic layout is used. NOTE(review): fragment —
// several return statements and braces are elided by extraction.
523static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
 524 Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
 525 const BufferizationOptions &
options,
const BufferizationState &state,
 526 SmallVector<Value> &invocationStack) {
 528 auto initArgBufferType =
 529 bufferization::getBufferType(initArg,
options, state, invocationStack);
 530 if (
failed(initArgBufferType))
// Recursion breaker: repeated occurrence on the stack means we are in a
// cyclic type computation — assume the init-arg type.
 533 if (llvm::count(invocationStack, iterArg) >= 2) {
 544 return *initArgBufferType;
 548 BufferLikeType yieldedValueBufferType;
 549 if (isa<BaseMemRefType>(yieldedValue.
getType())) {
// Yielded value is already a buffer.
 551 yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.
getType());
 555 auto maybeBufferType = bufferization::getBufferType(yieldedValue,
options,
 556 state, invocationStack);
 557 if (
failed(maybeBufferType))
 559 yieldedValueBufferType = *maybeBufferType;
 563 if (*initArgBufferType == yieldedValueBufferType)
 564 return yieldedValueBufferType;
// Types differ: layouts can be unified, memory spaces cannot.
 569 auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
 570 auto iterTensorType = cast<TensorType>(iterArg.
getType());
 571 auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
 572 if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
 574 "init_arg and yielded value bufferize to inconsistent memory spaces");
 576 if (
auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
 578 llvm::all_equal({yieldedRankedBufferType.getShape(),
 579 cast<MemRefType>(initBufferType).getShape(),
 580 cast<RankedTensorType>(iterTensorType).getShape()}) &&
 581 "expected same shape");
 584 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
 585 iterTensorType, yieldedBufferType.getMemorySpace()));
// Returns whether the scf.for loop may execute zero iterations.
// NOTE(review): fragment — the constant-bound extraction producing
// `lb`/`ub` and the comparison/return are elided; unknown bounds
// conservatively imply "may be zero iterations".
589bool mayHaveZeroIterations(scf::ForOp forOp) {
 592 if (!lb.has_value() || !ub.has_value())
600 :
public BufferizableOpInterface::ExternalModel<ForOpInterface,
602 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
603 const AnalysisState &state)
const {
604 auto forOp = cast<scf::ForOp>(op);
608 if (mayHaveZeroIterations(forOp))
613 return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
616 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
617 const AnalysisState &state)
const {
622 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
623 const AnalysisState &state)
const {
624 auto forOp = cast<scf::ForOp>(op);
625 OpResult opResult = forOp.getTiedLoopResult(&opOperand);
626 BufferRelation relation = bufferRelation(op, opResult, state);
627 return {{opResult, relation,
628 relation == BufferRelation::Equivalent}};
631 BufferRelation bufferRelation(Operation *op, OpResult opResult,
632 const AnalysisState &state)
const {
635 auto forOp = cast<scf::ForOp>(op);
636 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
637 bool equivalentYield = state.areEquivalentBufferizedValues(
638 bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
639 return equivalentYield ? BufferRelation::Equivalent
640 : BufferRelation::Unknown;
643 bool isWritable(Operation *op, Value value,
644 const AnalysisState &state)
const {
655 resolveConflicts(Operation *op, RewriterBase &rewriter,
656 const AnalysisState &analysisState,
657 const BufferizationState &bufferizationState)
const {
658 auto bufferizableOp = cast<BufferizableOpInterface>(op);
659 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(
660 rewriter, analysisState, bufferizationState)))
663 if (analysisState.getOptions().copyBeforeWrite)
671 auto forOp = cast<scf::ForOp>(op);
672 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
673 OpBuilder::InsertionGuard g(rewriter);
681 SmallVector<Value> yieldValues;
682 for (
const auto it : llvm::enumerate(yieldOp.getResults())) {
687 if (!
indices.contains(it.index()) ||
688 doesNotAliasExternalValue(
689 it.value(), &forOp.getRegion(),
690 forOp.getRegionIterArg(it.index()),
691 static_cast<const OneShotAnalysisState &
>(analysisState))) {
692 yieldValues.push_back(it.value());
695 FailureOr<Value> alloc = allocateTensorForShapedValue(
696 rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
700 yieldValues.push_back(*alloc);
704 yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
708 FailureOr<BufferLikeType>
710 const BufferizationState &state,
711 SmallVector<Value> &invocationStack)
const {
712 auto forOp = cast<scf::ForOp>(op);
714 assert(isa<TensorType>(value.
getType()) &&
"expected tensor type");
716 if (
auto opResult = dyn_cast<OpResult>(value)) {
718 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
719 return bufferization::getBufferType(bbArg,
options, state,
724 BlockArgument bbArg = cast<BlockArgument>(value);
725 unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
728 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
729 Value yieldedValue = yieldOp.getOperand(resultNum);
730 BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
731 Value initArg = forOp.getInitArgs()[resultNum];
732 return computeLoopRegionIterArgBufferType(
733 op, iterArg, initArg, yieldedValue,
options, state, invocationStack);
736 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
737 const BufferizationOptions &
options,
738 BufferizationState &state)
const {
739 auto forOp = cast<scf::ForOp>(op);
740 Block *oldLoopBody = forOp.getBody();
747 FailureOr<SmallVector<Value>> maybeInitArgs =
748 getBuffers(rewriter, forOp.getInitArgsMutable(),
options, state);
749 if (
failed(maybeInitArgs))
751 SmallVector<Value> initArgs = *maybeInitArgs;
754 SmallVector<Value> castedInitArgs;
755 for (
const auto &it : llvm::enumerate(initArgs)) {
756 Value initArg = it.value();
757 Value
result = forOp->getResult(it.index());
759 if (!isa<TensorType>(
result.getType())) {
760 castedInitArgs.push_back(initArg);
763 auto targetType = bufferization::getBufferType(
result,
options, state);
766 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
770 auto newForOp = scf::ForOp::create(
771 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
772 forOp.getStep(), castedInitArgs,
nullptr,
773 forOp.getUnsignedCmp());
774 newForOp->setAttrs(forOp->getAttrs());
775 Block *loopBody = newForOp.getBody();
780 SmallVector<Value> iterArgs =
781 getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
782 forOp.getRegionIterArgs(),
indices);
783 iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
786 rewriter.
mergeBlocks(oldLoopBody, loopBody, iterArgs);
789 replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
801 LogicalResult verifyAnalysis(Operation *op,
802 const AnalysisState &state)
const {
804 static_cast<const OneShotBufferizationOptions &
>(state.getOptions());
805 if (
options.allowReturnAllocsFromLoops)
808 auto forOp = cast<scf::ForOp>(op);
809 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
811 if (!isa<TensorType>(opResult.
getType()))
816 if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
817 return yieldOp->emitError()
819 <<
" is not equivalent to the corresponding iter bbArg";
// Bufferization model for scf.while. The before and after regions have
// independent argument lists, so results are tied to before-args via the
// scf.condition operands and to after-args via the scf.yield operands.
// NOTE(review): fragment truncated by extraction — many statements,
// return paths and braces are elided.
828struct WhileOpInterface
 829 :
public BufferizableOpInterface::ExternalModel<WhileOpInterface,
 831 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 832 const AnalysisState &state)
const {
 837 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 838 const AnalysisState &state)
const {
 843 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 844 const AnalysisState &state)
const {
 845 auto whileOp = cast<scf::WhileOp>(op);
 855 OpResult opResult = whileOp->getResult(idx);
 856 BufferRelation relation = bufferRelation(op, opResult, state);
 857 return {{opResult, relation,
 858 relation == BufferRelation::Equivalent}};
// Equivalent only when BOTH the condition operand is equivalent to the
// before-region bbArg AND the yield operand is equivalent to the
// after-region bbArg at the same position.
 861 BufferRelation bufferRelation(Operation *op, OpResult opResult,
 862 const AnalysisState &state)
const {
 867 auto whileOp = cast<scf::WhileOp>(op);
// Results past the before-argument count cannot be tracked.
 870 if (resultNumber >= whileOp.getBeforeArguments().size())
 871 return BufferRelation::Unknown;
 873 whileOp.getBeforeArguments()[resultNumber].getType())
 874 return BufferRelation::Unknown;
 876 auto conditionOp = whileOp.getConditionOp();
 877 BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
 878 Value conditionOperand = conditionOp.getArgs()[resultNumber];
 879 bool equivCondition =
 880 state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
 882 auto yieldOp = whileOp.getYieldOp();
 883 BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
 884 Value yieldOperand = yieldOp.getOperand(resultNumber);
 886 state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
 888 return equivCondition && equivYield ? BufferRelation::Equivalent
 889 : BufferRelation::Unknown;
 892 bool isWritable(Operation *op, Value value,
 893 const AnalysisState &state)
const {
// Resolves conflicts, then allocates copies for condition args that are
// not equivalent to their bbArgs in both regions.
 904 resolveConflicts(Operation *op, RewriterBase &rewriter,
 905 const AnalysisState &analysisState,
 906 const BufferizationState &bufferizationState)
const {
 907 auto bufferizableOp = cast<BufferizableOpInterface>(op);
 908 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(
 909 rewriter, analysisState, bufferizationState)))
 912 if (analysisState.getOptions().copyBeforeWrite)
 922 OpBuilder::InsertionGuard g(rewriter);
 923 auto whileOp = cast<scf::WhileOp>(op);
 924 auto conditionOp = whileOp.getConditionOp();
 929 whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
 931 getEquivalentBuffers(whileOp.getAfterArguments(),
 932 whileOp.getYieldOp().getResults(), analysisState);
 936 SmallVector<Value> beforeYieldValues;
 937 for (int64_t idx = 0;
 938 idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
 939 Value value = conditionOp.getArgs()[idx];
 940 if (!isa<TensorType>(value.
getType()) ||
 941 (equivalentYieldsAfter.contains(idx) &&
 942 equivalentYieldsBefore.contains(idx))) {
 943 beforeYieldValues.push_back(value);
 946 FailureOr<Value> alloc = allocateTensorForShapedValue(
 947 rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
 951 beforeYieldValues.push_back(*alloc);
 954 conditionOp.getArgsMutable().assign(beforeYieldValues);
// Rebuilds the scf.while over buffers: inits are bufferized and cast to
// the before-arg buffer types; both regions get fresh blocks with buffer
// argument types, and the old bodies are merged in with to_tensor
// replacements for tensor bbArgs.
 960 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 961 const BufferizationOptions &
options,
 962 BufferizationState &state)
const {
 963 auto whileOp = cast<scf::WhileOp>(op);
 969 getTensorIndices(whileOp.getAfterArguments());
 972 FailureOr<SmallVector<Value>> maybeInitArgs =
 973 getBuffers(rewriter, whileOp.getInitsMutable(),
options, state);
 974 if (
failed(maybeInitArgs))
 976 SmallVector<Value> initArgs = *maybeInitArgs;
 979 SmallVector<Value> castedInitArgs;
 980 for (
const auto &it : llvm::enumerate(initArgs)) {
 981 Value initArg = it.value();
 982 Value beforeArg = whileOp.getBeforeArguments()[it.index()];
 984 if (!isa<TensorType>(beforeArg.
getType())) {
 985 castedInitArgs.push_back(initArg);
 988 auto targetType = bufferization::getBufferType(beforeArg,
options, state);
 991 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
 995 SmallVector<Type> argsTypesAfter = llvm::to_vector(
 996 llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
 997 if (!isa<TensorType>(bbArg.getType()))
 998 return bbArg.getType();
 1000 return llvm::cast<Type>(
 1001 *bufferization::getBufferType(bbArg, options, state));
 1006 TypeRange argsTypesBefore(argsRangeBefore);
 1007 auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
 1008 argsTypesAfter, castedInitArgs);
 1011 SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
 1013 SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
 1015 Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
 1016 newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
 1017 Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
 1018 newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
 1024 SmallVector<Value> newBeforeArgs =
 1025 getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
 1026 whileOp.getBeforeArguments(), indicesBefore);
 1027 rewriter.
mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
 1033 SmallVector<Value> newAfterArgs =
 1034 getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
 1035 whileOp.getAfterArguments(), indicesAfter);
 1036 rewriter.
mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
 1039 replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
// Buffer type of an scf.while value: before-region bbArgs reconcile
// init + yielded types; results and after-region bbArgs take the type of
// the corresponding scf.condition operand.
 1044 FailureOr<BufferLikeType>
 1046 const BufferizationState &state,
 1047 SmallVector<Value> &invocationStack)
const {
 1048 auto whileOp = cast<scf::WhileOp>(op);
 1050 assert(isa<TensorType>(value.
getType()) &&
"expected tensor type");
 1053 if (
auto bbArg = dyn_cast<BlockArgument>(value)) {
 1055 Value initArg = whileOp.getInits()[bbArg.
getArgNumber()];
 1056 auto yieldOp = whileOp.getYieldOp();
 1057 Value yieldedValue = yieldOp.getOperand(bbArg.
getArgNumber());
 1058 return computeLoopRegionIterArgBufferType(
 1059 op, bbArg, initArg, yieldedValue,
options, state, invocationStack);
 1067 if (
auto opResult = dyn_cast<OpResult>(value)) {
 1069 }
else if (cast<BlockArgument>(value).getOwner()->getParent() ==
 1070 &whileOp.getAfter()) {
 1071 resultNum = cast<BlockArgument>(value).getArgNumber();
 1073 llvm_unreachable(
"invalid value");
 1075 Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
 1076 if (!isa<TensorType>(conditionYieldedVal.
getType())) {
// Already a buffer — e.g. in the middle of the bufferization rewrite.
 1078 return cast<BufferLikeType>(conditionYieldedVal.
getType());
 1080 return bufferization::getBufferType(conditionYieldedVal,
options, state,
// Unless allocs may be returned from loops, every tensor operand of both
// terminators (scf.condition and scf.yield) must be equivalent to the
// corresponding bbArg of its region.
 1094 LogicalResult verifyAnalysis(Operation *op,
 1095 const AnalysisState &state)
const {
 1096 auto whileOp = cast<scf::WhileOp>(op);
 1098 static_cast<const OneShotBufferizationOptions &
>(state.getOptions());
 1099 if (
options.allowReturnAllocsFromLoops)
 1102 auto conditionOp = whileOp.getConditionOp();
 1103 for (
const auto &it : llvm::enumerate(conditionOp.getArgs())) {
 1104 Block *block = conditionOp->getBlock();
 1105 if (!isa<TensorType>(it.value().getType()))
 1108 !state.areEquivalentBufferizedValues(it.value(),
 1110 return conditionOp->emitError()
 1111 <<
"Condition arg #" << it.index()
 1112 <<
" is not equivalent to the corresponding iter bbArg";
 1115 auto yieldOp = whileOp.getYieldOp();
 1116 for (
const auto &it : llvm::enumerate(yieldOp.getResults())) {
 1117 Block *block = yieldOp->getBlock();
 1118 if (!isa<TensorType>(it.value().getType()))
 1121 !state.areEquivalentBufferizedValues(it.value(),
 1123 return yieldOp->emitError()
 1124 <<
"Yield operand #" << it.index()
 1125 <<
" is not equivalent to the corresponding iter bbArg";
// Bufferization model for scf.yield (terminator of execute_region, if,
// index_switch, for, while). NOTE(review): fragment truncated by
// extraction — query-method bodies and several braces are elided.
1134struct YieldOpInterface
 1135 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
 1137 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 1138 const AnalysisState &state)
const {
 1142 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 1143 const AnalysisState &state)
const {
// Yield operands alias the parent op's results; for scf.if the alias is
// not definite (either branch may produce the result).
 1147 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 1148 const AnalysisState &state)
const {
 1149 if (
auto ifOp = dyn_cast<scf::IfOp>(op->
getParentOp())) {
 1151 BufferRelation::Equivalent,
false}};
 1155 BufferRelation::Equivalent}};
 1159 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
 1160 const AnalysisState &state)
const {
// Replaces tensor yield operands with buffers, casting them to the
// buffer type expected by the parent op (its result type for
// for/if/index_switch, the before-region arg type for while).
 1167 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 1168 const BufferizationOptions &
options,
 1169 BufferizationState &state)
const {
 1170 auto yieldOp = cast<scf::YieldOp>(op);
 1171 if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
 1172 scf::WhileOp>(yieldOp->getParentOp()))
 1173 return yieldOp->emitError(
"unsupported scf::YieldOp parent");
 1175 SmallVector<Value> newResults;
 1176 for (
const auto &it : llvm::enumerate(yieldOp.getResults())) {
 1177 Value value = it.value();
 1178 if (isa<TensorType>(value.
getType())) {
 1179 FailureOr<Value> maybeBuffer =
 1180 getBuffer(rewriter, value,
options, state);
 1183 Value buffer = *maybeBuffer;
 1185 if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
 1186 yieldOp->getParentOp())) {
 1187 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
 1188 yieldOp->getParentOp()->getResult(it.index()),
options, state);
 1191 buffer = castBuffer(rewriter, buffer, *resultType);
 1192 }
else if (
auto whileOp =
 1193 dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
 1194 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
 1195 whileOp.getBeforeArguments()[it.index()],
options, state);
 1198 buffer = castBuffer(rewriter, buffer, *resultType);
 1200 newResults.push_back(buffer);
 1202 newResults.push_back(value);
 1206 replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
// Bufferization model for scf.forall. Outputs bufferize in place; shared
// output bbArgs are rewired through to_tensor ops. NOTE(review): fragment
// truncated by extraction — several statements and braces are elided.
1215struct ForallOpInterface
 1216 :
public BufferizableOpInterface::ExternalModel<ForallOpInterface,
 1218 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
 1219 const AnalysisState &state)
const {
 1227 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
 1228 const AnalysisState &state)
const {
// Each output operand aliases its tied result, equivalently.
 1233 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
 1234 const AnalysisState &state)
const {
 1235 auto forallOp = cast<ForallOp>(op);
 1237 {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
 1240 bool isWritable(Operation *op, Value value,
 1241 const AnalysisState &state)
const {
// Rebuilds the scf.forall over the output buffers: the shared-output
// bbArgs (after the induction variables) are replaced by to_tensor
// wrappers of the buffers, and the op is replaced by the buffers.
 1245 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
 1246 const BufferizationOptions &
options,
 1247 BufferizationState &state)
const {
 1248 OpBuilder::InsertionGuard guard(rewriter);
 1249 auto forallOp = cast<ForallOp>(op);
 1250 int64_t rank = forallOp.getRank();
 1253 SmallVector<Value> buffers;
 1254 for (Value out : forallOp.getOutputs()) {
 1255 FailureOr<Value> buffer = getBuffer(rewriter, out,
options, state);
 1258 buffers.push_back(*buffer);
// The first `rank` bbArgs are induction variables; the rest are the
// shared output args being replaced here.
 1263 for (
const auto &it : llvm::zip(
 1264 forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
 1265 BlockArgument bbArg = std::get<0>(it);
 1266 Value buffer = std::get<1>(it);
 1267 Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
 1275 ForallOp newForallOp;
 1276 newForallOp = ForallOp::create(
 1277 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
 1278 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
 1284 rewriter.
eraseOp(newForallOp.getBody()->getTerminator());
// The new op has no tensor outputs, so pad the bbArg replacement list
// with null Values for the vanished shared-output args.
 1287 SmallVector<Value> replacementBbArgs;
 1288 replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
 1289 newForallOp.getBody()->getArguments().end());
 1290 replacementBbArgs.append(forallOp.getOutputs().size(), Value());
 1291 rewriter.
mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
 1295 replaceOpWithBufferizedValues(rewriter, op, buffers);
// Buffer types of forall values come from the tied output operand.
 1300 FailureOr<BufferLikeType>
 1302 const BufferizationState &state,
 1303 SmallVector<Value> &invocationStack)
const {
 1304 auto forallOp = cast<ForallOp>(op);
 1306 if (
auto bbArg = dyn_cast<BlockArgument>(value))
 1309 return bufferization::getBufferType(
 1310 forallOp.getTiedOpOperand(bbArg)->get(),
options, state,
 1315 return bufferization::getBufferType(
 1316 forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
options,
 1317 state, invocationStack);
// NOTE(review): the method header for the repetitive-region query below
// is elided by extraction; only its constant-bound check is visible.
 1321 auto forallOp = cast<ForallOp>(op);
 1325 for (
auto [lb, ub, step] :
 1326 llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
 1327 forallOp.getMixedStep())) {
 1340 if (*lbConstant + *stepConstant < *ubConstant)
 1346 bool isParallelRegion(Operation *op,
unsigned index)
const {
// Bufferization model for scf.in_parallel (scf.forall terminator).
// bufferize() must never be called: the op carries no tensor
// OpOperands/OpResults itself (handled as part of the parent forall).
1352struct InParallelOpInterface
 1353 :
public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
 1355 LogicalResult bufferize(Operation *op, RewriterBase &
b,
 1356 const BufferizationOptions &
options,
 1357 BufferizationState &state)
const {
 1358 llvm_unreachable(
"op does not have any tensor OpOperands / OpResults");
// Attaches the external bufferization models above to the SCF ops.
// NOTE(review): the enclosing registration function's header (and its
// DialectRegistry::addExtension wrapper) is elided by extraction.
 1370 ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
 1371 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
 1372 ForOp::attachInterface<ForOpInterface>(*ctx);
 1373 IfOp::attachInterface<IfOpInterface>(*ctx);
 1374 IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
 1375 ForallOp::attachInterface<ForallOpInterface>(*ctx);
 1376 InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
 1377 WhileOp::attachInterface<WhileOpInterface>(*ctx);
 1378 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static llvm::ManagedStatic< PassManagerOptions > options
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
static Operation * getOwnerOfValue(Value value)
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
bool isProperAncestor(Region *other)
Return true if this region is a proper ancestor of the other region.
bool hasOneBlock()
Return true if this region has exactly one block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
void applyOnAliases(Value v, function_ref< void(Value)> fun) const
Apply fun to all aliases of v.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet