/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    // ... (additional alias checks elided)
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}
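
/// Bufferization of scf.condition.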
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};
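
/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.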
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}
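
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.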
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // The corresponding yielded value is considered to be aliasing with the
    // op result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);
    return success();
  }
};
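
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.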
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // Both yield values of the then/else branches may alias with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};
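
/// Bufferization of scf.index_switch. Replace with a new scf.index_switch
/// that yields memrefs.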
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // The yielded values of all case regions and the default region may alias
    // with the op result.
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get the buffer type of a yielded value.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType =
          bufferization::getBufferType(yieldedValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute the buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute the buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both cases have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
static DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                              ValueRange yieldedValues,
                                              const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}
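
/// Helper function for loop bufferization. Return the bufferized values of
/// the given OpOperands. If an operand is not a tensor, return the original
/// value.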
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}
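
/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.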
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}
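
/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.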
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg,
    Value yieldedValue, const BufferizationOptions &options,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This avoids an infinite recursion when computing
    // the buffer type.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If the yielded type and the init_arg type are the same, use that type
  // directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // Otherwise, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}
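
/// Return "true" if the given loop may have 0 iterations.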
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}
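
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.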
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    // With zero loop iterations, the op results equal the init_args, meaning
    // that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;
    // Otherwise, an init_arg is only read if a use of its matching bbArg is.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();
    if (state.getOptions().copyBeforeWrite)
      return success();

    // Yielded tensors that may alias external buffers are copied into new
    // allocations, so that each OpResult aliases only its own init_arg.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    SmallVector<Value> yieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // An OpResult has the same bufferized type as the matching iter_arg.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Block argument: compute the buffer type from the init_arg and the
    // yielded value.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones
    // that are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(*maybeInitArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, do not cast.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move the loop body to the new op.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
    return success();
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the WhileOp results may have different
    // types.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();
    if (state.getOptions().copyBeforeWrite)
      return success();

    // Yielded tensors that are not equivalent to their matching bbArgs are
    // copied into new allocations.
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update the "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });
    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. The "before" and "after"
    // regions may have different bbArgs.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(*maybeInitArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, do not cast.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new
    // block in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    // ... (remaining parent-op cases elided)
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};
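
/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are only used during analysis, not for
/// bufferization.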
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create a new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());
    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over the block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);
    // The op is repetitive if the bounds and steps are dynamic or allow for
    // more than one iteration.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!lbConstant || !ubConstant || !stepConstant)
        return true;
      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};
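
/// Nothing to do for InParallelOp.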
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
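
// Example client-side registration (a sketch; `registry` and `context` are
// illustrative names, the calls are the standard MLIR registration APIs):
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect, bufferization::BufferizationDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);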