/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
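/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.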
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}
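/// Bufferization of scf.condition.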
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  // ... (analysis methods elided) ...
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};
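/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.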
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}
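/// Bufferization of scf.execute_region. Supports unstructured control flow,
/// but requires a unique scf.yield op (see `verifyAnalysis` below).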
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields is not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // The yielded value of the unique scf.yield aliases with the
    // corresponding op result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create a new op and move over the region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block signature.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op: wrap memref results in to_tensor ops.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace the old op.
    rewriter.replaceOp(executeRegionOp, newResults);
    return success();
  }
};
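/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.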
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // Both yielded values (then/else branch) may alias with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create a new scf.if that yields memrefs and move over the blocks.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp = rewriter.create<scf::IfOp>(
        ifOp.getLoc(), newTypes, ifOp.getCondition(), /*withElseRegion=*/true);
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the then/else branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // Then branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          thenValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // Else branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          elseValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to a fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};
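/// Bufferization of scf.index_switch. Replace with a new scf.index_switch
/// that yields memrefs.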
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // The yielded value of every case region (and the default region) may
    // alias with the switch result.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias({&yieldOp->getOpOperand(resultNum),
                       BufferRelation::Equivalent, /*isDefinite=*/false});
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias({&defaultYieldOp->getOpOperand(resultNum),
                     BufferRelation::Equivalent, /*isDefinite=*/false});
    return result;
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    // Compute bufferized result types.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create a new op that yields memrefs.
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over the blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get the buffer type yielded by a block.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };
    // Compute the buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // All other cases must yield a matching buffer type.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();
      if (bufferType == *yieldedBufferType)
        continue;
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");
      // Layout mismatch: promote to a fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }
    return bufferType;
  }
};
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}
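/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffers are equivalent.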
static DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                              ValueRange yieldedValues,
                                              const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}
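/// Helper function for loop bufferization. Given a list of init operands,
/// return the buffer of each tensor operand and all other operands unchanged.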
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}
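/// Helper function for loop bufferization. Given the bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the loop body can be moved over to the new op without
/// modification.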
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}
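/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value; on mismatch, a fully dynamic layout is
/// used.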
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg,
    Value yieldedValue, const BufferizationOptions &options,
    const BufferizationState &state, SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This avoids infinite recursion when computing the
    // buffer type.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(
        yieldedValue, options, state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If the yielded type and the init_arg type match, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // Otherwise, promote to a fully dynamic layout map; the memory space must
  // still be consistent.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}
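/// Return `true` if the given loop may have 0 iterations.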
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}
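/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.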
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args are "read".
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read; one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // An OpResult is equivalent to its corresponding init_arg if the
    // iter_arg and the yielded value are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &analysisState,
                                 const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // Yielded values that may alias an external value are replaced with
    // copies (alloc_tensor ops).
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    rewriter.setInsertionPoint(yieldOp);

    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    SmallVector<Value> yieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute the result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the buffer type of the yielded value.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't change it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move the loop body to the new loop and replace the results.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
    return success();
  }
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;
      // Note: This is overly strict; a "must-alias" analysis could relax it.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();
    // ... (OpOperands and OpResults may not line up; mismatches yield {}) ...
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // A result is equivalent to its init_arg if the corresponding iter_args
    // and yielded values are equivalent in both the "before" and "after"
    // regions.
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &analysisState,
                                 const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update the "before" region: insert buffer copies where needed.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });
    return success();
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't change it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the "before" block to the new op. The
    // old block uses tensors, so wrap the (memref) bbArgs in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Same for the "after" block.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // Either way, the buffer type is that of the value yielded by
    // scf.condition at the same index.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    // ... (handling for the other parent ops elided) ...
    return {};
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the buffer before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};
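/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are only used during analysis, not for
/// bufferization.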
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
    auto forallOp = cast<ForallOp>(op);
    // ... (analysis method body elided) ...
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }
    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }
    // Create a new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());
    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);
    return success();
  }
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
        options, state, invocationStack);
  }
  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);
    // The op is repetitive if any dimension may execute more than once (or if
    // the bounds/steps are dynamic).
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!lbConstant || !ubConstant || !stepConstant)
        return true;
      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }
  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};
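/// Nothing to do for InParallelOp.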
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
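// A minimal usage sketch (illustrative only, not part of this file): the
// external models above only take effect once the registration function has
// been called on a DialectRegistry used to construct the MLIRContext. The
// surrounding setup code below is an assumption, kept in comments so this
// listing stays compilable:
//
//   #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
//
//   DialectRegistry registry;
//   mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);  // SCF ops are now bufferizable.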