assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
       "scf.while op bufferization: cast incompatible");
return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
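// Helper used by the loop bufferization models below: judging from the
// fragment, it asks the analysis state for all aliases of `value` and reports
// whether every alias (apart from the listed exceptions) is defined inside
// `region`.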
static bool doesNotAliasExternalValue(Value value, Region *region,
         "expected region with single block");
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
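// Bufferization model for scf.condition: each tensor argument is replaced by
// its buffer, cast to the buffer type of the corresponding "after" block
// argument of the enclosing scf.while.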
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        if (failed(maybeBuffer))
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
        newArgs.push_back(value);
    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
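// Returns the single scf.yield terminator of an scf.execute_region, or a null
// op when no unique terminator of that kind exists.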
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
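// Bufferization model for scf.execute_region: a new execute_region op is
// created with the yielded buffer types, and tensor results are rebuilt
// through bufferization.to_tensor ops.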
struct ExecuteRegionOpInterface
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
  static bool supportsUnstructuredControlFlow() { return true; }

  LogicalResult verifyAnalysis(Operation *op,
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);

    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    for (Block &block : newOp.getRegion())
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
        newResults.push_back(newOp->getResult(it.index()));
    rewriter.replaceOp(executeRegionOp, newResults);
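// Bufferization model for scf.if: a new scf.if with bufferized result types is
// created and both branches are merged into it; getBufferType reconciles the
// buffer types yielded by the "then" and "else" branches.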
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);

    auto ifOp = cast<scf::IfOp>(op);
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
      if (failed(bufferType))
      newTypes.push_back(*bufferType);
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

  FailureOr<BaseMemRefType>
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    if (isa<BaseMemRefType>(thenValue.getType())) {
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
      auto maybeBufferType =
      if (failed(maybeBufferType))
      thenBufferType = *maybeBufferType;
    if (isa<BaseMemRefType>(elseValue.getType())) {
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
      auto maybeBufferType =
      if (failed(maybeBufferType))
      elseBufferType = *maybeBufferType;

    if (thenBufferType == elseBufferType)
      return thenBufferType;
      return op->emitError("inconsistent memory space on then/else branches");
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
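// Bufferization model for scf.index_switch: analogous to scf.if, but the
// yielded buffer types of the default block and of every case block must
// agree, at least in memory space.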
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());

    auto switchOp = cast<scf::IndexSwitchOp>(op);
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
      if (failed(bufferType))
      newTypes.push_back(*bufferType);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
                               newSwitchOp.getDefaultRegion(),
                               newSwitchOp.getDefaultRegion().begin());

  FailureOr<BaseMemRefType>
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
      auto maybeBufferType =
      if (failed(maybeBufferType))
      return maybeBufferType;
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
      if (bufferType == *yieldedBufferType)
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());

  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))

static FailureOr<SmallVector<Value>>
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
      if (failed(resultBuffer))
      result.push_back(*resultBuffer);
      result.push_back(opOperand.get());

    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
      result.push_back(val);
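// Computes the buffer type of a loop iter_arg from the init operand and the
// yielded value; when the two disagree (their memory spaces still have to
// match), a memref type with fully dynamic layout is used instead.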
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
  auto initArgBufferType =
  if (failed(initArgBufferType))

  if (llvm::count(invocationStack, iterArg) >= 2) {
    return *initArgBufferType;

  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
    auto maybeBufferType =
    if (failed(maybeBufferType))
    yieldedValueBufferType = *maybeBufferType;

  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
        "init_arg and yielded value bufferize to inconsistent memory spaces");
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
           llvm::all_equal({yieldedRankedBufferType.getShape(),
                            cast<MemRefType>(initBufferType).getShape(),
                            cast<RankedTensorType>(iterTensorType).getShape()}) &&
           "expected same shape");
      iterTensorType, yieldedBufferType.getMemorySpace());
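// Returns true if the scf.for bounds do not guarantee at least one iteration.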
bool mayHaveZeroIterations(scf::ForOp forOp) {
  if (!lb.has_value() || !ub.has_value())
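// Bufferization model for scf.for: init args are bufferized (and cast to the
// iter_arg buffer types), a new scf.for on memrefs is created, and the old
// loop body is merged into it.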
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
    auto forOp = cast<scf::ForOp>(op);
    if (mayHaveZeroIterations(forOp))
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));

    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    return {{opResult, relation,

    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());

    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
    if (!state.getOptions().enforceAliasingInvariants)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              forOp.getRegionIterArg(it.index()),
        yieldValues.push_back(it.value());
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      yieldValues.push_back(*alloc);
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });

  FailureOr<BaseMemRefType>
    auto forOp = cast<scf::ForOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");
    if (auto opResult = dyn_cast<OpResult>(value)) {
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);

    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
      if (failed(targetType))
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

  LogicalResult verifyAnalysis(Operation *op,
    if (options.allowReturnAllocsFromLoops)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      if (!isa<TensorType>(opResult.getType()))
        return yieldOp->emitError()
               << " is not equivalent to the corresponding iter bbArg";
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
    auto whileOp = cast<scf::WhileOp>(op);
    OpResult opResult = whileOp->getResult(idx);
    return {{opResult, relation,

    auto whileOp = cast<scf::WhileOp>(op);
    if (resultNumber >= whileOp.getBeforeArguments().size())
        whileOp.getBeforeArguments()[resultNumber].getType())
    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
    if (!state.getOptions().enforceAliasingInvariants)
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      beforeYieldValues.push_back(*alloc);
      conditionOp.getArgsMutable().assign(beforeYieldValues);

    auto whileOp = cast<scf::WhileOp>(op);
        getTensorIndices(whileOp.getAfterArguments());
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
      if (failed(targetType))
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

  FailureOr<BaseMemRefType>
    auto whileOp = cast<scf::WhileOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      auto yieldOp = whileOp.getYieldOp();
      return computeLoopRegionIterArgBufferType(
          op, bbArg, initArg, yieldedValue, options, invocationStack);

    if (auto opResult = dyn_cast<OpResult>(value)) {
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
      llvm_unreachable("invalid value");
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      return cast<BaseMemRefType>(conditionYieldedVal.getType());

  LogicalResult verifyAnalysis(Operation *op,
    auto whileOp = cast<scf::WhileOp>(op);
    if (options.allowReturnAllocsFromLoops)
    auto conditionOp = whileOp.getConditionOp();
    Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
          !state.areEquivalentBufferizedValues(it.value(),
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    auto yieldOp = whileOp.getYieldOp();
    Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
          !state.areEquivalentBufferizedValues(it.value(),
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {

    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        if (failed(maybeBuffer))
        Value buffer = *maybeBuffer;
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
          buffer = castBuffer(rewriter, buffer, *resultType);
        newResults.push_back(buffer);
        newResults.push_back(value);
    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
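// Returns true if any dimension of the scf.forall iteration space may be
// empty, judging by its constant lower/upper bounds.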
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
                                 forallOp.getMixedUpperBound())) {
    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
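// Bufferization model for scf.forall: shared outputs are bufferized, the tied
// block arguments are replaced by bufferization.to_tensor values, and a new
// scf.forall (apparently without shared outputs) is created in their place.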
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
    auto forallOp = cast<ForallOp>(op);
    if (mayHaveZeroIterations(forallOp))
    return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));

    auto forallOp = cast<ForallOp>(op);

    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();
    for (Value out : forallOp.getOutputs()) {
      buffers.push_back(*buffer);
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      Value buffer = std::get<1>(it);
      Value bufferAsTensor =
          rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
    rewriter.eraseOp(newForallOp.getBody()->getTerminator());
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),

  FailureOr<BaseMemRefType>
    auto forallOp = cast<ForallOp>(op);
    if (auto bbArg = dyn_cast<BlockArgument>(value))
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,

    auto forallOp = cast<ForallOp>(op);
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      if (*lbConstant + *stepConstant < *ubConstant)

  bool isParallelRegion(Operation *op, unsigned index) const {
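// scf.forall.in_parallel has no tensor operands or results, so its model is
// effectively a stub.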
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
  ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
  ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
  ForOp::attachInterface<ForOpInterface>(*ctx);
  IfOp::attachInterface<IfOpInterface>(*ctx);
  IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
  ForallOp::attachInterface<ForallOpInterface>(*ctx);
  InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
  WhileOp::attachInterface<WhileOpInterface>(*ctx);
  YieldOp::attachInterface<YieldOpInterface>(*ctx);
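// A minimal usage sketch, not part of this file: how a client would make the
// external models above available before running one-shot bufferization. The
// helper name setupScfBufferization and the surrounding setup are assumptions
// for illustration; only registerBufferizableOpInterfaceExternalModels and the
// MLIRContext/DialectRegistry API come from MLIR itself.

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

// Hypothetical helper: attach the SCF bufferization models to a registry so
// that a context built from it can bufferize scf.for, scf.while, scf.if, ...
void setupScfBufferization(mlir::DialectRegistry &registry) {
  mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
}

// Typical call site (assumption, not shown in this listing):
//   mlir::DialectRegistry registry;
//   setupScfBufferization(registry);
//   mlir::MLIRContext context(registry);
//   // ... load or build a module, then run the one-shot-bufferize pass ...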