/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
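// Example (sketch): an init buffer with identity layout may have to be cast
// to the fully dynamic layout computed for a loop iter_arg:
//   %cast = memref.cast %buf : memref<5xf32>
//                           to memref<5xf32, strided<[?], offset: ?>>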
/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->getBlocks().size() == 1 &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
  });
  return result;
}
/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};
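// Example (sketch): tensor arguments of scf.condition are rewritten to pass
// the buffers expected by the "after" region:
//   scf.condition(%cmp) %t : tensor<5xf32>
// becomes
//   scf.condition(%cmp) %m : memref<5xf32, strided<[?], offset: ?>>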
/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization
/// only with fully dynamic layout maps.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // scf.execute_region ops with multiple scf.yield ops are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. To allow for use-def
    // chain traversal in the analysis, the corresponding yield value is
    // considered to be aliasing with the op result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over the region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block signature.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op: wrap memref results in to_tensor ops.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace the old op.
    rewriter.replaceOp(executeRegionOp, newResults);
    return success();
  }
};
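// Example (sketch):
//   %r = scf.execute_region -> tensor<5xf32> {
//     scf.yield %t : tensor<5xf32>
//   }
// becomes
//   %m = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> { ... }
//   %r = bufferization.to_tensor %m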
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. Both corresponding yield values
    // from the then/else branches are considered to be aliasing with the
    // result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp = rewriter.create<scf::IfOp>(
        ifOp.getLoc(), newTypes, ifOp.getCondition(), /*withElseRegion=*/true);

    // Move over then/else blocks and replace the results.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Determine the buffer types of the then/else branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // Then branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // Else branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to a fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};
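// Example (sketch): if the two branches yield memrefs with different layouts,
// the scf.if result type is promoted to a fully dynamic layout:
//   then: memref<5xf32>    else: memref<5xf32, strided<[2]>>
//   result: memref<5xf32, strided<[?], offset: ?>>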
/// Bufferization of scf.index_switch. Replace with a new scf.index_switch
/// that yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. Every yielded value of
    // every case is considered to be aliasing with the result.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias({&yieldOp->getOpOperand(resultNum),
                       BufferRelation::Equivalent, /*isDefinite=*/false});
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias({&defaultYieldOp->getOpOperand(resultNum),
                     BufferRelation::Equivalent, /*isDefinite=*/false});
    return result;
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over all regions.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to compute the buffer type yielded by a block.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType =
          bufferization::getBufferType(yieldedValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute the buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute the buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both types are exactly the same.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to a fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }
    return bufferType;
  }
};
/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}
/// Helper function for loop bufferization. Return the bufferized values of
/// the given OpOperands. If an operand is not a tensor, return the original
/// value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}
/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}
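// Example (sketch): a memref bbArg %b of the new loop is wrapped so that the
// old tensor-based loop body can be moved over unchanged:
//   %t = bufferization.to_tensor %b : memref<5xf32, strided<[?], offset: ?>>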
/// Compute the bufferized type of a loop iter_arg. This type must be equal
/// to the bufferized type of the corresponding init_arg and the bufferized
/// type of the corresponding yielded value.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg,
    Value yieldedValue, const BufferizationOptions &options,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This avoids an infinite recursion when computing
    // the buffer type.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If the yielded type and the init_arg type match, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch, the buffer type must be promoted to a fully
  // dynamic layout map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}
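// Example (sketch): if the init_arg bufferizes to memref<5xf32> but the loop
// body yields memref<5xf32, strided<[?], offset: ?>>, the iter_arg's buffer
// type is promoted to the fully dynamic layout and the init buffer is cast
// accordingly (see castBuffer above).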
/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a
    // read.
    if (mayHaveZeroIterations(forOp))
      return true;
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants ||
        state.getOptions().copyBeforeWrite)
      return success();

    // A bufferized OpResult may alias only with the corresponding bufferized
    // init_arg (or with a newly allocated buffer). Yielded values that alias
    // with anything else defined outside of the loop are copied into a new
    // allocation.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Compute the result number and the bufferized type of the iter_arg.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();

    // Cast init_args to the bufferized result types if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(*maybeInitArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, do not cast.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move the loop body to the new op and replace the results.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
    return success();
  }
  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;
      // Note: This is overly strict; aliasing bufferized values would
      // suffice, but there is no "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // OpResults and OpOperands may not match; the only potentially aliasing
    // OpResult is the one at the same index.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent, for both the
    // "before" and the "after" block.
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants ||
        state.getOptions().copyBeforeWrite)
      return success();

    // Values yielded by the "before" block must be equivalent to their
    // corresponding bbArgs in both blocks; otherwise a copy into a new
    // allocation is inserted.
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update the "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });
    return success();
  }
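// Note (sketch): the allocation inserted by allocateTensorForShapedValue
// materializes a copy of the yielded value, so that the loop result does not
// alias with any external buffer.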
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. The "before" and "after"
    // regions may have different numbers of bbArgs.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();

    // Cast init_args to the bufferized "before" bbArg types if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(*maybeInitArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, do not cast.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move over the "before" block. The old block
    // uses tensors, so wrap the (memref) bbArgs of the new block in
    // ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody,
                         newBeforeArgs);

    // Same for the "after" block.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
    return success();
  }
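// Note: unlike scf.for, scf.while needs two sets of tensor indices and two
// rounds of casting/wrapping because the "before" and "after" regions may
// differ in the numbers and types of their bbArgs.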
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be computed directly from the
    // value yielded by the "before" region.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }
  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};
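// Example (sketch): inside an scf.if branch,
//   scf.yield %t : tensor<5xf32>
// becomes, after the buffer is looked up and cast to the scf.if result type:
//   scf.yield %m : memref<5xf32, strided<[?], offset: ?>>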
/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
                                 forallOp.getMixedUpperBound())) {
    std::optional<int64_t> lbConst = getConstantIntValue(lb);
    std::optional<int64_t> ubConst = getConstantIntValue(ub);
    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
      return true;
  }
  return false;
}
/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are only used during analysis, not for
/// bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding shared_outs, meaning that the shared_outs bufferize to a
    // read.
    if (mayHaveZeroIterations(forallOp))
      return true;

    // scf::ForallOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create a new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());
    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over the block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);
    return success();
  }
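// Note (sketch): the tensor results of scf.forall disappear entirely after
// bufferization; the loop writes into the shared_out buffers in place, and
// the op results are replaced by those buffers.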
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
        options, invocationStack);
  }
  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it has 1 or more steps. If the bounds or steps
    // are dynamic, it is also considered repetitive.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(),
                   forallOp.getMixedUpperBound(), forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;
      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;
      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;
      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};
/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
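// Typical client-side usage (a sketch; the surrounding pass setup is
// hypothetical):
//   DialectRegistry registry;
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);
// After this, One-Shot Bufferize can query the interfaces on scf ops.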