22#include "llvm/ADT/SmallVectorExtras.h"
// castBuffer: build a memref.CastOp that casts `buffer` to `type` and return
// the cast result. Asserts that both the destination type and the buffer's
// current type are BaseMemRefType and that the two are cast-compatible.
// NOTE(review): fused line numbers ("34", "35", ...) and mid-expression line
// breaks indicate this file is a garbled extraction; code kept verbatim.
// The closing brace of this function appears to be elided.
34static Value castBuffer(OpBuilder &
b, Value buffer, Type type) {
35 assert(isa<BaseMemRefType>(type) &&
"expected BaseMemRefType");
36 assert(isa<BaseMemRefType>(buffer.
getType()) &&
"expected BaseMemRefType");
// If the cast is not compatible, bufferization of the enclosing loop would
// produce invalid IR, hence the hard assert here.
43 assert(memref::CastOp::areCastCompatible(buffer.
getType(), type) &&
44 "scf.while op bufferization: cast incompatible");
45 return memref::CastOp::create(
b, buffer.
getLoc(), type, buffer).getResult();
// doesNotAliasExternalValue: checks whether `value` aliases only values
// inside `region` (aliases listed in `exceptions` are tolerated; OpResults
// whose defining region is not enclosed by `region` disqualify).
// NOTE(review): several original lines (e.g. 53, 56-57, 59-62) are elided by
// the extraction — the `exceptions`/`alias` bindings are not visible here.
52static bool doesNotAliasExternalValue(Value value, Region *region,
54 const OneShotAnalysisState &state) {
55 assert(region->
hasOneBlock() &&
"expected region with single block");
58 if (llvm::is_contained(exceptions, alias))
63 if (isa<OpResult>(alias) && !region->
isAncestor(aliasRegion))
// ConditionOpInterface: bufferization external model for scf.condition.
// Its bufferize() replaces tensor arguments with buffers, casting each buffer
// to the buffer type of the corresponding scf.while after-region argument.
// NOTE(review): method bodies of the query hooks (bufferizesToMemoryRead /
// Write, getAliasingValues, mustBufferizeInPlace) are elided by extraction.
70struct ConditionOpInterface
71 :
public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
73 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
74 const AnalysisState &state)
const {
78 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
79 const AnalysisState &state)
const {
83 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
84 const AnalysisState &state)
const {
88 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
89 const AnalysisState &state)
const {
// bufferize: tensor args become buffers; non-tensor args pass through.
96 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
97 const BufferizationOptions &
options,
98 BufferizationState &state)
const {
99 auto conditionOp = cast<scf::ConditionOp>(op);
100 auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
102 SmallVector<Value> newArgs;
103 for (
const auto &it : llvm::enumerate(conditionOp.getArgs())) {
104 Value value = it.value();
105 if (isa<TensorType>(value.
getType())) {
106 FailureOr<Value> maybeBuffer =
107 getBuffer(rewriter, value,
options, state);
// Target type is the buffer type of the matching after-region bbArg.
110 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
111 whileOp.getAfterArguments()[it.index()],
options, state);
114 Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
115 newArgs.push_back(buffer);
117 newArgs.push_back(value);
121 replaceOpWithNewBufferizedOp<scf::ConditionOp>(
122 rewriter, op, conditionOp.getCondition(), newArgs);
// getUniqueYieldOp: scans the blocks of an scf.execute_region and returns the
// scf.yield terminator — presumably null when the yield is not unique (TODO
// confirm; the accumulation/return lines are elided by the extraction).
129static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
131 for (
Block &block : executeRegionOp.getRegion()) {
132 if (
auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
// ExecuteRegionOpInterface: bufferization external model for
// scf.execute_region. Supports unstructured control flow; requires a unique
// scf.yield (verifyAnalysis rejects ops without one). bufferize() recreates
// the op with the yielded (buffer) result types, moves the region over, and
// wraps tensor results back with bufferization.to_tensor for the callers.
143struct ExecuteRegionOpInterface
144 :
public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
145 ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
147 static bool supportsUnstructuredControlFlow() {
return true; }
149 bool isWritable(Operation *op, Value value,
150 const AnalysisState &state)
const {
154 LogicalResult verifyAnalysis(Operation *op,
155 const AnalysisState &state)
const {
156 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
158 if (!getUniqueYieldOp(executeRegionOp))
159 return op->
emitOpError(
"op without unique scf.yield is not supported");
// Results alias the operands of the unique yield (Equivalent relation);
// block arguments are delegated to the branch-op-operand helper.
163 AliasingOpOperandList
164 getAliasingOpOperands(Operation *op, Value value,
165 const AnalysisState &state)
const {
166 if (
auto bbArg = dyn_cast<BlockArgument>(value))
167 return getAliasingBranchOpOperands(op, bbArg, state);
173 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
175 assert(it != op->
getOpResults().end() &&
"invalid value");
176 size_t resultNum = std::distance(op->
getOpResults().begin(), it);
177 auto yieldOp = getUniqueYieldOp(executeRegionOp);
181 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
184 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
185 const BufferizationOptions &
options,
186 BufferizationState &state)
const {
187 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
188 auto yieldOp = getUniqueYieldOp(executeRegionOp);
189 TypeRange newResultTypes(yieldOp.getResults());
192 auto newOp = scf::ExecuteRegionOp::create(
193 rewriter, op->
getLoc(), newResultTypes, executeRegionOp.getNoInline());
194 newOp.getRegion().takeBody(executeRegionOp.getRegion());
197 for (
Block &block : newOp.getRegion())
// Tensor results are rematerialized via to_tensor; others pass through.
204 SmallVector<Value> newResults;
205 for (
const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
206 if (isa<TensorType>(it.value())) {
207 newResults.push_back(bufferization::ToTensorOp::create(
208 rewriter, executeRegionOp.getLoc(), it.value(),
209 newOp->getResult(it.index())));
211 newResults.push_back(newOp->getResult(it.index()));
216 rewriter.
replaceOp(executeRegionOp, newResults);
// IfOpInterface: bufferization external model for scf.if.
// NOTE(review): the `struct IfOpInterface` header line is elided by the
// extraction; this fragment begins at the base-class clause.
// Results alias the matching then/else yield operands (Equivalent, but not
// definitely-equivalent — the trailing `false` flags).
224 :
public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
225 AliasingOpOperandList
226 getAliasingOpOperands(Operation *op, Value value,
227 const AnalysisState &state)
const {
232 auto ifOp = cast<scf::IfOp>(op);
233 size_t resultNum = std::distance(op->
getOpResults().begin(),
235 OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
236 OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
237 return {{thenOperand, BufferRelation::Equivalent,
false},
238 {elseOperand, BufferRelation::Equivalent,
false}};
// bufferize: recreate scf.if with buffer result types and merge both
// branch blocks into the new op.
241 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
242 const BufferizationOptions &
options,
243 BufferizationState &state)
const {
244 OpBuilder::InsertionGuard g(rewriter);
245 auto ifOp = cast<scf::IfOp>(op);
248 SmallVector<Type> newTypes;
249 for (Value
result : ifOp.getResults()) {
250 if (!isa<TensorType>(
result.getType())) {
251 newTypes.push_back(
result.getType());
254 auto bufferType = bufferization::getBufferType(
result,
options, state);
257 newTypes.push_back(*bufferType);
262 auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
267 rewriter.
mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
268 rewriter.
mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
271 replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
// getBufferType (name elided): compute the result buffer type from both
// branch yields; equal types are returned as-is, differing memory spaces
// are an error, otherwise fall back to a fully dynamic layout.
276 FailureOr<BufferLikeType>
278 const BufferizationState &state,
279 SmallVector<Value> &invocationStack)
const {
280 auto ifOp = cast<scf::IfOp>(op);
281 auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
282 auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
286 auto opResult = cast<OpResult>(value);
287 auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
288 auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
289 BaseMemRefType thenBufferType, elseBufferType;
290 if (isa<BaseMemRefType>(thenValue.getType())) {
292 thenBufferType = cast<BaseMemRefType>(thenValue.getType());
294 auto maybeBufferType =
295 bufferization::detail::asMemRefType(bufferization::getBufferType(
296 thenValue,
options, state, invocationStack));
297 if (
failed(maybeBufferType))
299 thenBufferType = *maybeBufferType;
301 if (isa<BaseMemRefType>(elseValue.getType())) {
303 elseBufferType = cast<BaseMemRefType>(elseValue.getType());
305 auto maybeBufferType =
306 bufferization::detail::asMemRefType(bufferization::getBufferType(
307 elseValue,
options, state, invocationStack));
308 if (
failed(maybeBufferType))
310 elseBufferType = *maybeBufferType;
314 if (thenBufferType == elseBufferType)
315 return cast<BufferLikeType>(thenBufferType);
319 return op->
emitError(
"inconsistent memory space on then/else branches");
322 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
323 cast<TensorType>(opResult.getType()), thenBufferType.
getMemorySpace()));
// IndexSwitchOpInterface: bufferization external model for scf.index_switch.
// A result aliases the corresponding yield operand of every case block and of
// the default block (all BufferRelation::Equivalent).
329struct IndexSwitchOpInterface
330 :
public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
331 scf::IndexSwitchOp> {
332 AliasingOpOperandList
333 getAliasingOpOperands(Operation *op, Value value,
334 const AnalysisState &state)
const {
337 auto switchOp = cast<scf::IndexSwitchOp>(op);
338 int64_t resultNum = cast<OpResult>(value).getResultNumber();
339 AliasingOpOperandList
result;
340 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
342 cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
343 result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
344 BufferRelation::Equivalent,
347 auto defaultYieldOp =
348 cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
349 result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
350 BufferRelation::Equivalent,
// bufferize: recreate the switch with buffer result types and move all
// case regions plus the default region into the new op.
355 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
356 const BufferizationOptions &
options,
357 BufferizationState &state)
const {
358 OpBuilder::InsertionGuard g(rewriter);
359 auto switchOp = cast<scf::IndexSwitchOp>(op);
362 SmallVector<Type> newTypes;
363 for (Value
result : switchOp.getResults()) {
364 if (!isa<TensorType>(
result.getType())) {
365 newTypes.push_back(
result.getType());
368 auto bufferType = bufferization::getBufferType(
result,
options, state);
371 newTypes.push_back(*bufferType);
376 auto newSwitchOp = scf::IndexSwitchOp::create(
377 rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
378 switchOp.getCases(), switchOp.getCases().size());
381 for (
auto [src, dest] :
382 llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
385 newSwitchOp.getDefaultRegion(),
386 newSwitchOp.getDefaultRegion().begin());
389 replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
// getBufferType (name elided): unify yielded buffer types across default
// and case blocks; equal types pass through, differing memory spaces error,
// otherwise a fully dynamic layout is used (assignment operand elided).
394 FailureOr<BufferLikeType>
396 const BufferizationState &state,
397 SmallVector<Value> &invocationStack)
const {
398 auto switchOp = cast<scf::IndexSwitchOp>(op);
400 int64_t resultNum = cast<OpResult>(value).getResultNumber();
403 auto getYieldedBufferType = [&](
Block &
b) -> FailureOr<BaseMemRefType> {
404 auto yieldOp = cast<scf::YieldOp>(
b.getTerminator());
405 Value yieldedValue = yieldOp->getOperand(resultNum);
406 if (
auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.
getType()))
408 auto maybeBufferType = bufferization::getBufferType(
409 yieldedValue,
options, state, invocationStack);
410 return bufferization::detail::asMemRefType(maybeBufferType);
414 auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
415 if (
failed(maybeBufferType))
417 BaseMemRefType bufferType = *maybeBufferType;
420 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
421 auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
422 if (
failed(yieldedBufferType))
426 if (bufferType == *yieldedBufferType)
430 if (bufferType.
getMemorySpace() != yieldedBufferType->getMemorySpace())
431 return op->
emitError(
"inconsistent memory space on switch cases");
434 bufferType = getMemRefTypeWithFullyDynamicLayout(
438 return cast<BufferLikeType>(bufferType);
// Helper fragments (both signatures partially elided by the extraction):
// (1) tail of getTensorIndices — collects into `result` the indices of
//     values whose type is a TensorType;
// (2) getEquivalentBuffers — walks paired bbArgs/yieldedValues and records
//     indices where both are tensors and bufferize to equivalent buffers
//     per `state` (accumulation lines elided).
446 for (
const auto &it : llvm::enumerate(values))
447 if (isa<TensorType>(it.value().getType()))
448 result.insert(it.index());
456 const AnalysisState &state) {
457 unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
459 for (
unsigned int i = 0; i < minSize; ++i) {
460 if (!isa<TensorType>(bbArgs[i].
getType()) ||
461 !isa<TensorType>(yieldedValues[i].
getType()))
463 if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
// getBuffers: for each operand in `operands`, push its bufferized value
// (via getBuffer) when it is a tensor, otherwise push the operand as-is.
// Failure-propagation and return lines are elided by the extraction.
471static FailureOr<SmallVector<Value>>
472getBuffers(RewriterBase &rewriter,
const MutableOperandRange &operands,
473 const BufferizationOptions &
options, BufferizationState &state) {
474 SmallVector<Value>
result;
475 for (OpOperand &opOperand : operands) {
476 if (isa<TensorType>(opOperand.get().getType())) {
477 FailureOr<Value> resultBuffer =
478 getBuffer(rewriter, opOperand.get(),
options, state);
481 result.push_back(*resultBuffer);
483 result.push_back(opOperand.get());
// getBbArgReplacements (parameter list elided, original lines 493-495):
// builds replacement values for block arguments — for indices in
// `tensorIndices` the new (buffer) bbArg is wrapped in a
// bufferization.to_tensor with the old bbArg's tensor type.
492static SmallVector<Value>
496 SmallVector<Value>
result;
497 for (
const auto &it : llvm::enumerate(bbArgs)) {
498 size_t idx = it.index();
499 Value val = it.value();
500 if (tensorIndices.contains(idx)) {
502 bufferization::ToTensorOp::create(rewriter, val.
getLoc(),
503 oldBbArgs[idx].getType(), val)
// computeLoopRegionIterArgBufferType: compute the buffer type of a loop
// iter_arg by reconciling the init operand's buffer type with the yielded
// value's buffer type. Recursion is cut off via `invocationStack` (when the
// iterArg appears >= 2 times, fall back to the init type). Matching types
// are returned directly; mismatched memory spaces are an error; otherwise
// a memref with fully dynamic layout (same memory space) is produced.
524static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
525 Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
526 const BufferizationOptions &
options,
const BufferizationState &state,
527 SmallVector<Value> &invocationStack) {
529 auto initArgBufferType =
530 bufferization::getBufferType(initArg,
options, state, invocationStack);
531 if (
failed(initArgBufferType))
534 if (llvm::count(invocationStack, iterArg) >= 2) {
545 return *initArgBufferType;
549 BufferLikeType yieldedValueBufferType;
550 if (isa<BaseMemRefType>(yieldedValue.
getType())) {
552 yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.
getType());
556 auto maybeBufferType = bufferization::getBufferType(yieldedValue,
options,
557 state, invocationStack);
558 if (
failed(maybeBufferType))
560 yieldedValueBufferType = *maybeBufferType;
564 if (*initArgBufferType == yieldedValueBufferType)
565 return yieldedValueBufferType;
570 auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
571 auto iterTensorType = cast<TensorType>(iterArg.
getType());
572 auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
573 if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
575 "init_arg and yielded value bufferize to inconsistent memory spaces");
577 if (
auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
579 llvm::all_equal({yieldedRankedBufferType.getShape(),
580 cast<MemRefType>(initBufferType).getShape(),
581 cast<RankedTensorType>(iterTensorType).getShape()}) &&
582 "expected same shape");
585 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
586 iterTensorType, yieldedBufferType.getMemorySpace()));
// mayHaveZeroIterations: presumably returns true when the scf.for bounds are
// not both constant (lb/ub bindings and return lines elided — TODO confirm
// against the original source).
590bool mayHaveZeroIterations(scf::ForOp forOp) {
593 if (!lb.has_value() || !ub.has_value())
// ForOpInterface: bufferization external model for scf.for.
// NOTE(review): the `struct ForOpInterface` header line is elided by the
// extraction; this fragment begins at the base-class clause.
601 :
public BufferizableOpInterface::ExternalModel<ForOpInterface,
// Reads: a loop that may run zero times forwards its init operands, and an
// iter_arg that is read inside the loop counts as a read of the operand.
603 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
604 const AnalysisState &state)
const {
605 auto forOp = cast<scf::ForOp>(op);
609 if (mayHaveZeroIterations(forOp))
614 return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
617 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
618 const AnalysisState &state)
const {
623 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
624 const AnalysisState &state)
const {
625 auto forOp = cast<scf::ForOp>(op);
626 OpResult opResult = forOp.getTiedLoopResult(&opOperand);
627 BufferRelation relation = bufferRelation(op, opResult, state);
628 return {{opResult, relation,
629 relation == BufferRelation::Equivalent}};
// Equivalent only when the yielded value bufferizes equivalently to the
// tied region iter_arg.
632 BufferRelation bufferRelation(Operation *op, OpResult opResult,
633 const AnalysisState &state)
const {
636 auto forOp = cast<scf::ForOp>(op);
637 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
638 bool equivalentYield = state.areEquivalentBufferizedValues(
639 bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
640 return equivalentYield ? BufferRelation::Equivalent
641 : BufferRelation::Unknown;
644 bool isWritable(Operation *op, Value value,
645 const AnalysisState &state)
const {
// resolveConflicts: after generic conflict resolution, non-equivalent
// yielded tensors are replaced by explicit allocations (unless they do not
// alias values defined outside the loop region).
656 resolveConflicts(Operation *op, RewriterBase &rewriter,
657 const AnalysisState &analysisState,
658 const BufferizationState &bufferizationState)
const {
659 auto bufferizableOp = cast<BufferizableOpInterface>(op);
660 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(
661 rewriter, analysisState, bufferizationState)))
664 if (analysisState.getOptions().copyBeforeWrite)
672 auto forOp = cast<scf::ForOp>(op);
673 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
674 OpBuilder::InsertionGuard g(rewriter);
682 SmallVector<Value> yieldValues;
683 for (
const auto it : llvm::enumerate(yieldOp.getResults())) {
688 if (!
indices.contains(it.index()) ||
689 doesNotAliasExternalValue(
690 it.value(), &forOp.getRegion(),
691 forOp.getRegionIterArg(it.index()),
692 static_cast<const OneShotAnalysisState &
>(analysisState))) {
693 yieldValues.push_back(it.value());
696 FailureOr<Value> alloc = allocateTensorForShapedValue(
697 rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
701 yieldValues.push_back(*alloc);
705 yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
// getBufferType (name elided): results delegate to the tied iter_arg;
// iter_args are computed via computeLoopRegionIterArgBufferType.
709 FailureOr<BufferLikeType>
711 const BufferizationState &state,
712 SmallVector<Value> &invocationStack)
const {
713 auto forOp = cast<scf::ForOp>(op);
715 assert(isa<TensorType>(value.
getType()) &&
"expected tensor type");
717 if (
auto opResult = dyn_cast<OpResult>(value)) {
719 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
720 return bufferization::getBufferType(bbArg,
options, state,
725 BlockArgument bbArg = cast<BlockArgument>(value);
726 unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
729 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
730 Value yieldedValue = yieldOp.getOperand(resultNum);
731 BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
732 Value initArg = forOp.getInitArgs()[resultNum];
733 return computeLoopRegionIterArgBufferType(
734 op, iterArg, initArg, yieldedValue,
options, state, invocationStack);
// bufferize: buffer-ize init args, cast them to the result buffer types,
// build a new scf.for, splice the old body in, and replace the loop.
737 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
738 const BufferizationOptions &
options,
739 BufferizationState &state)
const {
740 auto forOp = cast<scf::ForOp>(op);
741 Block *oldLoopBody = forOp.getBody();
748 FailureOr<SmallVector<Value>> maybeInitArgs =
749 getBuffers(rewriter, forOp.getInitArgsMutable(),
options, state);
750 if (
failed(maybeInitArgs))
752 SmallVector<Value> initArgs = *maybeInitArgs;
755 SmallVector<Value> castedInitArgs;
756 for (
const auto &it : llvm::enumerate(initArgs)) {
757 Value initArg = it.value();
758 Value
result = forOp->getResult(it.index());
760 if (!isa<TensorType>(
result.getType())) {
761 castedInitArgs.push_back(initArg);
764 auto targetType = bufferization::getBufferType(
result,
options, state);
767 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
771 auto newForOp = scf::ForOp::create(
772 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
773 forOp.getStep(), castedInitArgs,
nullptr,
774 forOp.getUnsignedCmp());
775 newForOp->setAttrs(forOp->getAttrs());
776 Block *loopBody = newForOp.getBody();
781 SmallVector<Value> iterArgs =
782 getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
783 forOp.getRegionIterArgs(),
indices);
784 iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
787 rewriter.
mergeBlocks(oldLoopBody, loopBody, iterArgs);
790 replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
// verifyAnalysis: unless allowReturnAllocsFromLoops, every tensor result
// must be Equivalent to its iter bbArg.
802 LogicalResult verifyAnalysis(Operation *op,
803 const AnalysisState &state)
const {
805 static_cast<const OneShotBufferizationOptions &
>(state.getOptions());
806 if (
options.allowReturnAllocsFromLoops)
809 auto forOp = cast<scf::ForOp>(op);
810 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
812 if (!isa<TensorType>(opResult.
getType()))
817 if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
818 return yieldOp->emitError()
820 <<
" is not equivalent to the corresponding iter bbArg";
// WhileOpInterface: bufferization external model for scf.while.
// NOTE(review): many interior lines are elided by the extraction (gaps in
// the fused numbering); code kept verbatim.
829struct WhileOpInterface
830 :
public BufferizableOpInterface::ExternalModel<WhileOpInterface,
832 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
833 const AnalysisState &state)
const {
838 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
839 const AnalysisState &state)
const {
844 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
845 const AnalysisState &state)
const {
846 auto whileOp = cast<scf::WhileOp>(op);
856 OpResult opResult = whileOp->getResult(idx);
857 BufferRelation relation = bufferRelation(op, opResult, state);
858 return {{opResult, relation,
859 relation == BufferRelation::Equivalent}};
// Equivalent only when BOTH the scf.condition operand is equivalent to the
// before-region bbArg AND the yield operand is equivalent to the
// after-region bbArg; results past the before-arg count are Unknown.
862 BufferRelation bufferRelation(Operation *op, OpResult opResult,
863 const AnalysisState &state)
const {
868 auto whileOp = cast<scf::WhileOp>(op);
871 if (resultNumber >= whileOp.getBeforeArguments().size())
872 return BufferRelation::Unknown;
874 whileOp.getBeforeArguments()[resultNumber].getType())
875 return BufferRelation::Unknown;
877 auto conditionOp = whileOp.getConditionOp();
878 BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
879 Value conditionOperand = conditionOp.getArgs()[resultNumber];
880 bool equivCondition =
881 state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
883 auto yieldOp = whileOp.getYieldOp();
884 BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
885 Value yieldOperand = yieldOp.getOperand(resultNumber);
887 state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
889 return equivCondition && equivYield ? BufferRelation::Equivalent
890 : BufferRelation::Unknown;
893 bool isWritable(Operation *op, Value value,
894 const AnalysisState &state)
const {
// resolveConflicts: condition args whose before/after equivalences do not
// both hold are replaced by explicit tensor allocations.
905 resolveConflicts(Operation *op, RewriterBase &rewriter,
906 const AnalysisState &analysisState,
907 const BufferizationState &bufferizationState)
const {
908 auto bufferizableOp = cast<BufferizableOpInterface>(op);
909 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(
910 rewriter, analysisState, bufferizationState)))
913 if (analysisState.getOptions().copyBeforeWrite)
923 OpBuilder::InsertionGuard g(rewriter);
924 auto whileOp = cast<scf::WhileOp>(op);
925 auto conditionOp = whileOp.getConditionOp();
930 whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
932 getEquivalentBuffers(whileOp.getAfterArguments(),
933 whileOp.getYieldOp().getResults(), analysisState);
937 SmallVector<Value> beforeYieldValues;
938 for (int64_t idx = 0;
939 idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
940 Value value = conditionOp.getArgs()[idx];
941 if (!isa<TensorType>(value.
getType()) ||
942 (equivalentYieldsAfter.contains(idx) &&
943 equivalentYieldsBefore.contains(idx))) {
944 beforeYieldValues.push_back(value);
947 FailureOr<Value> alloc = allocateTensorForShapedValue(
948 rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
952 beforeYieldValues.push_back(*alloc);
955 conditionOp.getArgsMutable().assign(beforeYieldValues);
// bufferize: buffer-ize init args, cast to before-arg buffer types, build
// a new scf.while with buffer block arguments, and splice both bodies in.
961 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
962 const BufferizationOptions &
options,
963 BufferizationState &state)
const {
964 auto whileOp = cast<scf::WhileOp>(op);
970 getTensorIndices(whileOp.getAfterArguments());
973 FailureOr<SmallVector<Value>> maybeInitArgs =
974 getBuffers(rewriter, whileOp.getInitsMutable(),
options, state);
975 if (
failed(maybeInitArgs))
977 SmallVector<Value> initArgs = *maybeInitArgs;
980 SmallVector<Value> castedInitArgs;
981 for (
const auto &it : llvm::enumerate(initArgs)) {
982 Value initArg = it.value();
983 Value beforeArg = whileOp.getBeforeArguments()[it.index()];
985 if (!isa<TensorType>(beforeArg.
getType())) {
986 castedInitArgs.push_back(initArg);
989 auto targetType = bufferization::getBufferType(beforeArg,
options, state);
992 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
996 SmallVector<Type> argsTypesAfter = llvm::map_to_vector(
997 whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
998 if (!isa<TensorType>(bbArg.getType()))
999 return bbArg.getType();
1001 return llvm::cast<Type>(
1002 *bufferization::getBufferType(bbArg, options, state));
1007 TypeRange argsTypesBefore(argsRangeBefore);
1008 auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
1009 argsTypesAfter, castedInitArgs);
1012 SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
1014 SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
1016 Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
1017 newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
1018 Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
1019 newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
1025 SmallVector<Value> newBeforeArgs =
1026 getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
1027 whileOp.getBeforeArguments(), indicesBefore);
1028 rewriter.
mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
1034 SmallVector<Value> newAfterArgs =
1035 getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
1036 whileOp.getAfterArguments(), indicesAfter);
1037 rewriter.
mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
1040 replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
// getBufferType (name elided): before-region bbArgs are loop iter_args;
// results / after-region bbArgs take the type of the matching
// scf.condition operand.
1045 FailureOr<BufferLikeType>
1047 const BufferizationState &state,
1048 SmallVector<Value> &invocationStack)
const {
1049 auto whileOp = cast<scf::WhileOp>(op);
1051 assert(isa<TensorType>(value.
getType()) &&
"expected tensor type");
1054 if (
auto bbArg = dyn_cast<BlockArgument>(value)) {
1056 Value initArg = whileOp.getInits()[bbArg.
getArgNumber()];
1057 auto yieldOp = whileOp.getYieldOp();
1058 Value yieldedValue = yieldOp.getOperand(bbArg.
getArgNumber());
1059 return computeLoopRegionIterArgBufferType(
1060 op, bbArg, initArg, yieldedValue,
options, state, invocationStack);
1068 if (
auto opResult = dyn_cast<OpResult>(value)) {
1070 }
else if (cast<BlockArgument>(value).getOwner()->getParent() ==
1071 &whileOp.getAfter()) {
1072 resultNum = cast<BlockArgument>(value).getArgNumber();
1074 llvm_unreachable(
"invalid value");
1076 Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
1077 if (!isa<TensorType>(conditionYieldedVal.
getType())) {
1079 return cast<BufferLikeType>(conditionYieldedVal.
getType());
1081 return bufferization::getBufferType(conditionYieldedVal,
options, state,
// verifyAnalysis: unless allowReturnAllocsFromLoops, condition args and
// yield operands must be equivalent to their corresponding iter bbArgs.
1095 LogicalResult verifyAnalysis(Operation *op,
1096 const AnalysisState &state)
const {
1097 auto whileOp = cast<scf::WhileOp>(op);
1099 static_cast<const OneShotBufferizationOptions &
>(state.getOptions());
1100 if (
options.allowReturnAllocsFromLoops)
1103 auto conditionOp = whileOp.getConditionOp();
1104 for (
const auto &it : llvm::enumerate(conditionOp.getArgs())) {
1105 Block *block = conditionOp->getBlock();
1106 if (!isa<TensorType>(it.value().getType()))
1109 !state.areEquivalentBufferizedValues(it.value(),
1111 return conditionOp->emitError()
1112 <<
"Condition arg #" << it.index()
1113 <<
" is not equivalent to the corresponding iter bbArg";
1116 auto yieldOp = whileOp.getYieldOp();
1117 for (
const auto &it : llvm::enumerate(yieldOp.getResults())) {
1118 Block *block = yieldOp->getBlock();
1119 if (!isa<TensorType>(it.value().getType()))
1122 !state.areEquivalentBufferizedValues(it.value(),
1124 return yieldOp->emitError()
1125 <<
"Yield operand #" << it.index()
1126 <<
" is not equivalent to the corresponding iter bbArg";
// YieldOpInterface: bufferization external model for scf.yield.
// Only supports yields nested in execute_region / if / index_switch / for /
// while; bufferize() swaps tensor operands for buffers, casting them to the
// parent op's result (or while before-arg) buffer type.
1135struct YieldOpInterface
1136 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
1138 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1139 const AnalysisState &state)
const {
1143 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1144 const AnalysisState &state)
const {
1148 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1149 const AnalysisState &state)
const {
1150 if (
auto ifOp = dyn_cast<scf::IfOp>(op->
getParentOp())) {
1152 BufferRelation::Equivalent,
false}};
1156 BufferRelation::Equivalent}};
1160 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
1161 const AnalysisState &state)
const {
1168 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1169 const BufferizationOptions &
options,
1170 BufferizationState &state)
const {
1171 auto yieldOp = cast<scf::YieldOp>(op);
1172 if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
1173 scf::WhileOp>(yieldOp->getParentOp()))
1174 return yieldOp->emitError(
"unsupported scf::YieldOp parent");
1176 SmallVector<Value> newResults;
1177 for (
const auto &it : llvm::enumerate(yieldOp.getResults())) {
1178 Value value = it.value();
1179 if (isa<TensorType>(value.
getType())) {
1180 FailureOr<Value> maybeBuffer =
1181 getBuffer(rewriter, value,
options, state);
1184 Value buffer = *maybeBuffer;
// For structured parents the target is the parent's result buffer type;
// for scf.while it is the before-region bbArg buffer type.
1186 if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
1187 yieldOp->getParentOp())) {
1188 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
1189 yieldOp->getParentOp()->getResult(it.index()),
options, state);
1192 buffer = castBuffer(rewriter, buffer, *resultType);
1193 }
else if (
auto whileOp =
1194 dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1195 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
1196 whileOp.getBeforeArguments()[it.index()],
options, state);
1199 buffer = castBuffer(rewriter, buffer, *resultType);
1201 newResults.push_back(buffer);
1203 newResults.push_back(value);
1207 replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
// ForallOpInterface: bufferization external model for scf.forall.
// Outputs are bufferized in place (results are Equivalent to the tied
// output operand); bufferize() rebuilds the op without shared outputs and
// rewires body bbArgs through to_tensor values.
1216struct ForallOpInterface
1217 :
public BufferizableOpInterface::ExternalModel<ForallOpInterface,
1219 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1220 const AnalysisState &state)
const {
1228 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1229 const AnalysisState &state)
const {
1234 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1235 const AnalysisState &state)
const {
1236 auto forallOp = cast<ForallOp>(op);
1238 {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
1241 bool isWritable(Operation *op, Value value,
1242 const AnalysisState &state)
const {
1246 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1247 const BufferizationOptions &
options,
1248 BufferizationState &state)
const {
1249 OpBuilder::InsertionGuard guard(rewriter);
1250 auto forallOp = cast<ForallOp>(op);
1251 int64_t rank = forallOp.getRank();
1254 SmallVector<Value> buffers;
1255 for (Value out : forallOp.getOutputs()) {
1256 FailureOr<Value> buffer = getBuffer(rewriter, out,
options, state);
1259 buffers.push_back(*buffer);
// Each output bbArg (after the induction vars) is re-materialized as a
// tensor view of its buffer inside the body.
1264 for (
const auto &it : llvm::zip(
1265 forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
1266 BlockArgument bbArg = std::get<0>(it);
1267 Value buffer = std::get<1>(it);
1268 Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
1276 ForallOp newForallOp;
1277 newForallOp = ForallOp::create(
1278 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1279 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1285 rewriter.
eraseOp(newForallOp.getBody()->getTerminator());
1288 SmallVector<Value> replacementBbArgs;
1289 replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
1290 newForallOp.getBody()->getArguments().end());
1291 replacementBbArgs.append(forallOp.getOutputs().size(), Value());
1292 rewriter.
mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
1296 replaceOpWithBufferizedValues(rewriter, op, buffers);
// getBufferType (name elided): both bbArgs and results delegate to the
// tied output operand's buffer type.
1301 FailureOr<BufferLikeType>
1303 const BufferizationState &state,
1304 SmallVector<Value> &invocationStack)
const {
1305 auto forallOp = cast<ForallOp>(op);
1307 if (
auto bbArg = dyn_cast<BlockArgument>(value))
1310 return bufferization::getBufferType(
1311 forallOp.getTiedOpOperand(bbArg)->get(),
options, state,
1316 return bufferization::getBufferType(
1317 forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
options,
1318 state, invocationStack);
// Repetitive-region check (header elided): inspects constant lb/ub/step to
// decide whether the loop body may execute more than once.
1322 auto forallOp = cast<ForallOp>(op);
1326 for (
auto [lb, ub, step] :
1327 llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1328 forallOp.getMixedStep())) {
1341 if (*lbConstant + *stepConstant < *ubConstant)
1347 bool isParallelRegion(Operation *op,
unsigned index)
const {
// InParallelOpInterface: bufferization external model for scf.in_parallel.
// The op itself carries no tensor operands/results, so bufferize() must
// never be called on it — hence the llvm_unreachable.
1353struct InParallelOpInterface
1354 :
public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
1356 LogicalResult bufferize(Operation *op, RewriterBase &
b,
1357 const BufferizationOptions &
options,
1358 BufferizationState &state)
const {
1359 llvm_unreachable(
"op does not have any tensor OpOperands / OpResults");
// Registration body (enclosing function header elided by the extraction):
// attaches each external model defined above to its SCF op in `ctx`.
1371 ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
1372 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1373 ForOp::attachInterface<ForOpInterface>(*ctx);
1374 IfOp::attachInterface<IfOpInterface>(*ctx);
1375 IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
1376 ForallOp::attachInterface<ForallOpInterface>(*ctx);
1377 InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
1378 WhileOp::attachInterface<WhileOpInterface>(*ctx);
1379 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static llvm::ManagedStatic< PassManagerOptions > options
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
static Operation * getOwnerOfValue(Value value)
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
bool isProperAncestor(Region *other)
Return true if this region is a proper ancestor of the other region.
bool hasOneBlock()
Return true if this region has exactly one block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
void applyOnAliases(Value v, function_ref< void(Value)> fun) const
Apply fun to all aliases of v.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet