/// Helper function for loop bufferization: cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
}
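
// Usage sketch (hypothetical helper, for illustration only): widen a buffer
// to the fully dynamic layout before carrying it through a loop, reusing
// getMemRefTypeWithFullyDynamicLayout, the same utility the interfaces below
// fall back to when buffer types disagree.
static Value promoteToFullyDynamicLayout(OpBuilder &b, Value buffer,
                                         TensorType tensorType) {
  BaseMemRefType target = getMemRefTypeWithFullyDynamicLayout(
      tensorType, cast<BaseMemRefType>(buffer.getType()).getMemorySpace());
  return castBuffer(b, buffer, target);
}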
/// Helper function for loop bufferization: return "true" if the given value
/// is guaranteed to not alias, apart from `exceptions`, with any value
/// defined outside of `region`.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->hasOneBlock() && "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
    // ... (block arguments are handled analogously)
  });
  return result;
}
/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  // ...

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        // Cast to the buffer type of the matching "after" block argument.
        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};
/// Return the unique scf.yield op. If there are multiple or no scf.yield
/// ops, return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization
/// only supports ops with a unique scf.yield terminator.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // scf.execute_region ops with multiple scf.yield ops are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // The corresponding yield operand aliases with the op result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create a new op and move over the region.
    auto newOp =
        scf::ExecuteRegionOp::create(rewriter, op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block signature.
    for (Block &block : newOp.getRegion())
      if (failed(bufferizeBlockSignature(&block, rewriter, options, state)))
        return failure();

    // Update all uses of the old op: wrap memref results in to_tensor ops.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(bufferization::ToTensorOp::create(
            rewriter, executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace the old op.
    rewriter.replaceOp(executeRegionOp, newResults);
    return success();
  }
};
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    // Both yielded values (then/else branch) may alias with the op result.
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create a new scf.if op and move over the then/else blocks.
    auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
                                     ifOp.getCondition(),
                                     /*withElseRegion=*/true);
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
    // ... (replace the old op with the new results)
  }
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Determine the buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType = asMemRefType(bufferization::getBufferType(
          thenValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType = asMemRefType(bufferization::getBufferType(
          elseValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return cast<BufferLikeType>(thenBufferType);

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: promote to a fully dynamic layout map.
    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
  }
};
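
// For illustration, the three-way merging rule above can be read as the
// following hypothetical helper (not part of this file): equal types pass
// through, mismatched memory spaces are an error, and mismatched layouts are
// widened to a fully dynamic layout.
static FailureOr<BaseMemRefType>
mergeBranchBufferTypes(Operation *op, TensorType tensorType,
                       BaseMemRefType thenType, BaseMemRefType elseType) {
  if (thenType == elseType)
    return thenType;
  if (thenType.getMemorySpace() != elseType.getMemorySpace())
    return op->emitError("inconsistent memory space on then/else branches");
  return getMemRefTypeWithFullyDynamicLayout(tensorType,
                                             thenType.getMemorySpace());
}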
/// Bufferization of scf.index_switch. Replace with a new scf.index_switch
/// that yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // The yielded value of each case (and of the default region) may alias
    // with the corresponding op result.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias({&yieldOp->getOpOperand(resultNum),
                       BufferRelation::Equivalent, /*isDefinite=*/false});
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias({&defaultYieldOp->getOpOperand(resultNum),
                     BufferRelation::Equivalent, /*isDefinite=*/false});
    return result;
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create a new op and move over the case/default regions.
    auto newSwitchOp = scf::IndexSwitchOp::create(
        rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
        switchOp.getCases(), switchOp.getCases().size());
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());
    // ... (replace the old op with the new results)
  }
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function: get the buffer type yielded by the given block.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      return asMemRefType(bufferization::getBufferType(
          yieldedValue, options, state, invocationStack));
    };

    // Compute the buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute the buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: the same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: promote to a fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return cast<BufferLikeType>(bufferType);
  }
};
/// Helper function for loop bufferization: return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization: return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
static DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                              ValueRange yieldedValues,
                                              const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization: return the bufferized values of
/// the given OpOperands. Non-tensor operands are returned as-is.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization: given the bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor bbArgs (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(bufferization::ToTensorOp::create(
                           rewriter, val.getLoc(), oldBbArgs[idx].getType(),
                           val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}
/// Compute the bufferized type of a loop iter_arg. This type must be equal
/// to the bufferized type of the corresponding init_arg and the bufferized
/// type of the corresponding yielded value.
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg,
    Value yieldedValue, const BufferizationOptions &options,
    const BufferizationState &state, SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take
    // the type of the init_arg. This avoids infinite recursion when
    // computing the buffer type: the yielded-value query below typically
    // recurses back to this iter_arg.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BufferLikeType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(
        yieldedValue, options, state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If the yielded type and the init_arg type match, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // Otherwise, promote to a fully dynamic layout map; the memory spaces must
  // still match.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace()));
}
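
// Worked example (hand-constructed types, for illustration only): if the
// init_arg bufferizes to memref<5xf32> and the yielded value bufferizes to
// memref<5xf32, strided<[?], offset: ?>>, the memory spaces match but the
// layouts differ, so the iter_arg type is widened to the fully dynamic
// layout. `iterArgWideningExample` and `ctx` are assumptions of this sketch.
static BaseMemRefType iterArgWideningExample(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  auto iterTensorType = RankedTensorType::get({5}, f32);
  BaseMemRefType initType = MemRefType::get({5}, f32); // static layout
  BaseMemRefType yieldType =
      getMemRefTypeWithFullyDynamicLayout(iterTensorType);
  assert(initType != yieldType && "layouts differ");
  assert(initType.getMemorySpace() == yieldType.getMemorySpace());
  // Same widening step as in computeLoopRegionIterArgBufferType:
  return getMemRefTypeWithFullyDynamicLayout(iterTensorType,
                                             yieldType.getMemorySpace());
}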
/// Return "true" if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    // If the loop may have zero iterations, the results are their
    // corresponding init_args, i.e., the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;
    // Otherwise, an init_arg is read only if its matching iter_arg is read.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  // ...

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // A loop result is equivalent to its init_arg if the corresponding
    // iter_arg and yielded value are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // Yielded tensors that alias with a value defined outside of the loop
    // (other than the corresponding iter_arg) are copied into a new buffer.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    SmallVector<Value> yieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute the result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones
    // that are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args (if necessary) to the expected loop-carried buffer type.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = scf::ForOp::create(
        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
        forOp.getUnsignedCmp());
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move the old loop body to the new op.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace the loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
    return success();
  }
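
  // Net effect on the IR (hand-written example; types illustrative):
  //
  //   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
  //       -> (tensor<5xf32>) { ... scf.yield %t2 : tensor<5xf32> }
  //
  // becomes a loop over memrefs whose bbArg is re-wrapped for the old body:
  //
  //   %m = scf.for %iv = %lb to %ub step %s iter_args(%b = %init_buffer)
  //       -> (memref<5xf32, strided<[?], offset: ?>>) {
  //     %t = bufferization.to_tensor %b : ...
  //     ...   // old body; the scf.yield is bufferized by YieldOpInterface
  //   }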
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  // ...

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();
    // ... (OpOperands and OpResults may not match in number or type)
    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // A result is equivalent to its init_arg if both the condition operand
    // (in the "before" block) and the yielded value (in the "after" block)
    // are equivalent to the corresponding bbArg.
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // Insert buffer copies for all condition args that are not equivalent to
    // their corresponding bbArg in both the "before" and "after" blocks, so
    // that the result types match the "before" block argument types.
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // Indices of all bbArgs that have tensor type and are equivalent.
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. The "before" and "after"
    // regions may have different bbArgs.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args (if necessary).
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType =
          bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of the op are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
                                           argsTypesAfter, castedInitArgs);

    // Add the before/after blocks to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the condition block to the new op. The
    // old block uses tensors, so wrap the (memref) bbArgs in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody,
                         newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace the loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
    return success();
  }
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: bbArg of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state,
            invocationStack);
      }
    }

    // Case 2: op result or bbArg of the "after" region. The buffer type is
    // the type of the corresponding scf.condition operand.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BufferLikeType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
/// Bufferization of scf.yield. Bufferized as part of the enclosing op, so
/// this interface is used for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  // ...

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    // ...
    return {};
  }

  // ...

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};
/// Bufferization of scf.forall. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are only used during analysis, not for
/// bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  // ...

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of the tensor block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
                                                bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create a new scf.forall op without any results.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = ForallOp::create(
        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());
    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over the block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);
    return success();
  }
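
  // Net effect on the IR (hand-written example): the bufferized scf.forall
  // has no results and no shared_outs; the body writes directly into the
  // output buffers, and all uses of the old op's results are replaced by
  // those buffers:
  //
  //   %r = scf.forall (%i) in (%n) shared_outs(%o = %t) -> (tensor<?xf32>)
  // becomes
  //   scf.forall (%i) in (%n) { ... }  // %o is rewritten to read %t's buffer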
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    // A tensor block argument has the same bufferized type as the
    // corresponding output operand.
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type matches the corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
        options, state, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);
    // The region is repetitive if the loop has more than one iteration along
    // any dimension, or if that cannot be ruled out statically.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(),
                   forallOp.getMixedUpperBound(), forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!lbConstant || !ubConstant || !stepConstant)
        return true;
      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }
  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
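
// Usage sketch: wiring these external models into a context before running a
// bufferization pass. `initScfBufferizationModels` is a hypothetical helper;
// the dialect list is an assumption of what a typical tool would register.
void initScfBufferizationModels(DialectRegistry &registry) {
  registry.insert<scf::SCFDialect, memref::MemRefDialect,
                  bufferization::BufferizationDialect>();
  scf::registerBufferizableOpInterfaceExternalModels(registry);
  // An MLIRContext constructed from `registry` attaches the interfaces above
  // as soon as the SCF dialect is loaded.
}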