17#include "llvm/ADT/SmallVectorExtras.h"
30 auto srcType = llvm::cast<MemRefType>(value.
getType());
33 if (srcType.getElementType() != destType.getElementType())
35 if (srcType.getRank() != destType.getRank())
41 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType
target) {
42 int64_t sourceOffset, targetOffset;
44 if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
45 failed(
target.getStridesAndOffset(targetStrides, targetOffset)))
48 return ShapedType::isDynamic(a) && ShapedType::isStatic(
b);
50 if (dynamicToStatic(sourceOffset, targetOffset))
52 for (
auto it : zip(sourceStrides, targetStrides))
53 if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
61 if (memref::CastOp::areCastCompatible(srcType, destType) &&
62 isGuaranteedCastCompatible(srcType, destType)) {
63 Value casted = memref::CastOp::create(
b, value.
getLoc(), destType, value);
69 for (
int i = 0; i < destType.getRank(); ++i) {
70 if (destType.getShape()[i] != ShapedType::kDynamic)
72 Value size = memref::DimOp::create(
b, loc, value, i);
73 dynamicOperands.push_back(size);
76 FailureOr<Value>
copy =
77 options.createAlloc(
b, loc, destType, dynamicOperands);
90 auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
94 Type srcType = bufferToTensor.getBuffer().getType();
95 Type destType = toBuffer.getType();
98 if (srcType == destType) {
99 rewriter.
replaceOp(toBuffer, bufferToTensor.getBuffer());
103 auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
104 auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
105 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
108 if (rankedSrcType && rankedDestType) {
110 rewriter, bufferToTensor.getBuffer(), rankedDestType,
options);
120 if (unrankedSrcType && rankedDestType)
125 assert(memref::CastOp::areCastCompatible(srcType, destType) &&
126 "expected that types are cast compatible");
128 bufferToTensor.getBuffer());
135 auto shapedType = llvm::cast<ShapedType>(shapedValue.
getType());
136 for (
int64_t i = 0; i < shapedType.getRank(); ++i) {
137 if (shapedType.isDynamicDim(i)) {
138 if (llvm::isa<MemRefType>(shapedType)) {
139 dynamicDims.push_back(memref::DimOp::create(
b, loc, shapedValue, i));
141 assert(llvm::isa<RankedTensorType>(shapedType) &&
"expected tensor");
142 dynamicDims.push_back(tensor::DimOp::create(
b, loc, shapedValue, i));
152LogicalResult AllocTensorOp::bufferize(
RewriterBase &rewriter,
154 BufferizationState &state) {
159 if (getOperation()->getUses().empty()) {
160 rewriter.
eraseOp(getOperation());
167 FailureOr<Value> maybeCopyBuffer =
168 getBuffer(rewriter, getCopy(),
options, state);
169 if (failed(maybeCopyBuffer))
171 copyBuffer = *maybeCopyBuffer;
175 auto allocType = bufferization::getBufferType(getResult(),
options, state);
180 assert(dynamicDims.empty() &&
"expected either `copy` or `dynamicDims`");
183 FailureOr<Value> alloc =
options.createAlloc(
184 rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
190 if (
failed(
options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
195 replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
200bool AllocTensorOp::resultBufferizesToMemoryWrite(
OpResult opResult,
203 return static_cast<bool>(getCopy());
206bool AllocTensorOp::bufferizesToMemoryRead(
OpOperand &opOperand,
209 "expected copy operand");
213bool AllocTensorOp::bufferizesToMemoryWrite(
OpOperand &opOperand,
216 "expected copy operand");
220AliasingValueList AllocTensorOp::getAliasingValues(
OpOperand &opOperand,
226FailureOr<BufferLikeType>
228 const BufferizationState &state,
230 assert(value == getResult() &&
"invalid value");
234 if (getMemorySpace().has_value()) {
235 memorySpace = *getMemorySpace();
236 }
else if (getCopy()) {
237 auto copyBufferType =
238 bufferization::detail::asMemRefType(bufferization::getBufferType(
239 getCopy(),
options, state, invocationStack));
240 if (
failed(copyBufferType))
242 memorySpace = copyBufferType->getMemorySpace();
243 }
else if (
auto ms =
options.defaultMemorySpaceFn(
244 cast<TensorLikeType>(
getType()))) {
247 return getOperation()->emitError(
"could not infer memory space");
250 return cast<BufferLikeType>(
251 getMemRefTypeWithStaticIdentityLayout(
getType(), memorySpace));
254LogicalResult AllocTensorOp::verify() {
256 return emitError(
"dynamic sizes not needed when copying a tensor");
261 return emitError(
"expected that `copy` and return type match");
266 RankedTensorType type,
ValueRange dynamicSizes) {
267 build(builder,
result, type, dynamicSizes,
Value(),
273 RankedTensorType type,
ValueRange dynamicSizes,
281 IntegerAttr memorySpace) {
299 using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
301 LogicalResult matchAndRewrite(AllocTensorOp op,
302 PatternRewriter &rewriter)
const override {
305 SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
306 SmallVector<Value> newDynamicSizes;
307 unsigned int dynValCounter = 0;
308 for (int64_t i = 0; i < op.getType().getRank(); ++i) {
309 if (!op.isDynamicDim(i))
311 Value value = op.getDynamicSizes()[dynValCounter++];
314 int64_t dim = intVal.getSExtValue();
316 newShape[i] = intVal.getSExtValue();
318 newDynamicSizes.push_back(value);
320 newDynamicSizes.push_back(value);
323 RankedTensorType newType = RankedTensorType::get(
324 newShape, op.getType().getElementType(), op.getType().getEncoding());
325 if (newType == op.getType())
327 auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
328 newDynamicSizes, Value());
335 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
337 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
338 PatternRewriter &rewriter)
const override {
339 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
340 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
341 if (!allocTensorOp || !maybeConstantIndex)
343 if (*maybeConstantIndex < 0 ||
344 *maybeConstantIndex >= allocTensorOp.getType().getRank())
346 if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
349 dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
357 results.
add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
360LogicalResult AllocTensorOp::reifyResultShapes(
363 llvm::map_to_vector<4>(llvm::seq<int64_t>(0,
getType().getRank()),
365 if (isDynamicDim(dim))
369 reifiedReturnShapes.emplace_back(std::move(shapes));
380 if (copyKeyword.succeeded())
386 if (sizeHintKeyword.succeeded())
400 if (copyKeyword.succeeded())
403 if (sizeHintKeyword.succeeded())
406 result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
408 {static_cast<int32_t>(dynamicSizesOperands.size()),
409 static_cast<int32_t>(copyKeyword.succeeded()),
410 static_cast<int32_t>(sizeHintKeyword.succeeded())}));
417 p <<
" copy(" << getCopy() <<
")";
419 p <<
" size_hint=" << getSizeHint();
421 AllocTensorOp::getOperandSegmentSizeAttr()});
423 auto type = getResult().getType();
424 if (
auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
431 assert(isDynamicDim(idx) &&
"expected dynamic dim");
433 return tensor::DimOp::create(
b, getLoc(), getCopy(), idx);
434 return getOperand(getIndexOfDynamicSize(idx));
450 using OpRewritePattern<CloneOp>::OpRewritePattern;
452 LogicalResult matchAndRewrite(CloneOp cloneOp,
453 PatternRewriter &rewriter)
const override {
454 if (cloneOp.use_empty()) {
459 Value source = cloneOp.getInput();
460 if (source.
getType() != cloneOp.getType() &&
461 !memref::CastOp::areCastCompatible({source.getType()},
462 {cloneOp.getType()}))
467 Value canonicalSource = source;
468 while (
auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
470 if (canonicalSource != iface.getViewDest()) {
473 canonicalSource = iface.getViewSource();
476 std::optional<Operation *> maybeCloneDeallocOp =
479 if (!maybeCloneDeallocOp.has_value())
481 std::optional<Operation *> maybeSourceDeallocOp =
483 if (!maybeSourceDeallocOp.has_value())
485 Operation *cloneDeallocOp = *maybeCloneDeallocOp;
486 Operation *sourceDeallocOp = *maybeSourceDeallocOp;
490 if (cloneDeallocOp && sourceDeallocOp &&
494 Block *currentBlock = cloneOp->getBlock();
495 Operation *redundantDealloc =
nullptr;
496 if (cloneDeallocOp && cloneDeallocOp->
getBlock() == currentBlock) {
497 redundantDealloc = cloneDeallocOp;
498 }
else if (sourceDeallocOp && sourceDeallocOp->
getBlock() == currentBlock) {
499 redundantDealloc = sourceDeallocOp;
502 if (!redundantDealloc)
510 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
511 pos = pos->getNextNode()) {
515 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
516 if (!effectInterface)
518 if (effectInterface.hasEffect<MemoryEffects::Free>())
522 if (source.
getType() != cloneOp.getType())
523 source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
524 cloneOp.getType(), source);
526 rewriter.
eraseOp(redundantDealloc);
535 results.
add<SimplifyClones>(context);
542LogicalResult DeallocTensorOp::bufferize(
RewriterBase &rewriter,
544 BufferizationState &state) {
545 FailureOr<Value> buffer = getBuffer(rewriter, getTensor(),
options, state);
548 memref::DeallocOp::create(rewriter, getLoc(), *buffer);
549 rewriter.
eraseOp(getOperation());
557bool MaterializeInDestinationOp::bufferizesToMemoryRead(
559 return opOperand == getSourceMutable();
562bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
564 if (opOperand == getDestMutable()) {
565 assert(isa<TensorType>(getDest().
getType()) &&
"expected tensor type");
571bool MaterializeInDestinationOp::mustBufferizeInPlace(
580MaterializeInDestinationOp::getAliasingValues(
OpOperand &opOperand,
582 if (opOperand == getDestMutable()) {
583 assert(isa<TensorType>(getDest().
getType()) &&
"expected tensor type");
584 return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
590MaterializeInDestinationOp::bufferize(
RewriterBase &rewriter,
592 BufferizationState &state) {
593 bool tensorDest = isa<TensorType>(getDest().
getType());
596 FailureOr<Value> maybeBuffer =
597 getBuffer(rewriter, getDest(),
options, state);
600 buffer = *maybeBuffer;
602 assert(isa<BaseMemRefType>(getDest().
getType()) &&
"expected memref type");
605 auto srcBuffer = getBuffer(rewriter, getSource(),
options, state);
608 if (
failed(
options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
610 replaceOpWithBufferizedValues(rewriter, getOperation(),
615bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
622LogicalResult MaterializeInDestinationOp::reifyResultShapes(
624 if (getOperation()->getNumResults() == 1) {
625 assert(isa<TensorType>(getDest().
getType()) &&
"expected tensor type");
626 reifiedReturnShapes.resize(1,
628 reifiedReturnShapes[0] =
634Value MaterializeInDestinationOp::buildSubsetExtraction(
OpBuilder &builder,
636 if (isa<TensorType>(getDest().
getType())) {
649 assert(isa<BaseMemRefType>(getDest().
getType()) &&
"expected memref type");
650 assert(getRestrict() &&
651 "expected that ops with memrefs dest have 'restrict'");
653 return ToTensorOp::create(
656 true, getWritable());
659bool MaterializeInDestinationOp::isEquivalentSubset(
661 return equivalenceFn(getDest(), candidate);
665MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
669OpOperand &MaterializeInDestinationOp::getSourceOperand() {
670 return getOperation()->getOpOperand(0) ;
673bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
674 SubsetOpInterface subsetOp,
679bool MaterializeInDestinationOp::operatesOnDisjointSubset(
680 SubsetOpInterface subsetOp,
685LogicalResult MaterializeInDestinationOp::verify() {
686 if (!isa<TensorType, BaseMemRefType>(getDest().
getType()))
687 return emitOpError(
"'dest' must be a tensor or a memref");
688 if (
auto destType = dyn_cast<TensorType>(getDest().
getType())) {
689 if (getOperation()->getNumResults() != 1)
690 return emitOpError(
"tensor 'dest' implies exactly one tensor result");
691 if (destType != getResult().
getType())
692 return emitOpError(
"result and 'dest' types must match");
694 if (isa<BaseMemRefType>(getDest().
getType()) &&
695 getOperation()->getNumResults() != 0)
696 return emitOpError(
"memref 'dest' implies zero results");
697 if (getRestrict() && !isa<BaseMemRefType>(getDest().
getType()))
698 return emitOpError(
"'restrict' is valid only for memref destinations");
699 if (getWritable() != isa<BaseMemRefType>(getDest().
getType()))
700 return emitOpError(
"'writable' must be specified if and only if the "
701 "destination is of memref type");
703 ShapedType destType = cast<ShapedType>(getDest().
getType());
704 if (srcType.
hasRank() != destType.hasRank())
705 return emitOpError(
"source/destination shapes are incompatible");
710 for (
auto [src, dest] :
711 llvm::zip(srcType.
getShape(), destType.getShape())) {
712 if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
718 return emitOpError(
"source/destination shapes are incompatible");
724void MaterializeInDestinationOp::build(
OpBuilder &builder,
727 auto destTensorType = dyn_cast<TensorType>(dest.
getType());
728 build(builder, state, destTensorType ? destTensorType :
Type(),
732bool MaterializeInDestinationOp::isWritable(
Value value,
734 return isa<TensorType>(getDest().
getType()) ?
true : getWritable();
738 return getDestMutable();
741void MaterializeInDestinationOp::getEffects(
744 if (isa<BaseMemRefType>(getDest().
getType()))
754 return getWritable();
758 if (
auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
761 if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
762 toBuffer->getNextNode() == this->getOperation())
763 return toBuffer.getTensor();
769 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
771 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
772 PatternRewriter &rewriter)
const override {
773 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
774 if (!memrefToTensorOp)
778 dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
786 results.
add<DimOfToTensorFolder>(context);
794 if (
auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
795 if (memrefToTensor.getBuffer().getType() ==
getType())
796 return memrefToTensor.getBuffer();
804 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
806 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
807 PatternRewriter &rewriter)
const final {
808 auto tensorCastOperand =
809 toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
810 if (!tensorCastOperand)
812 auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
813 tensorCastOperand.getOperand().getType());
816 auto currentOutputMemRefType =
817 dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
818 if (!currentOutputMemRefType)
821 auto memrefType = currentOutputMemRefType.cloneWith(
822 srcTensorType.getShape(), srcTensorType.getElementType());
823 Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
824 tensorCastOperand.getOperand(),
825 toBuffer.getReadOnly());
835 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
837 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
838 PatternRewriter &rewriter)
const final {
848 using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
850 LogicalResult matchAndRewrite(memref::LoadOp
load,
851 PatternRewriter &rewriter)
const override {
852 auto toBuffer =
load.getMemref().getDefiningOp<ToBufferOp>();
853 if (!toBuffer || !toBuffer.getReadOnly())
864 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
866 LogicalResult matchAndRewrite(memref::DimOp dimOp,
867 PatternRewriter &rewriter)
const override {
868 auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
871 Value newSource = castOp.getOperand();
882 results.
add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
883 ToBufferToTensorFolding>(context);
886LogicalResult ToBufferOp::bufferize(
RewriterBase &rewriter,
888 BufferizationState &state) {
896std::optional<Operation *> CloneOp::buildDealloc(
OpBuilder &builder,
898 return memref::DeallocOp::create(builder, alloc.
getLoc(), alloc)
902std::optional<Value> CloneOp::buildClone(
OpBuilder &builder,
Value alloc) {
903 return CloneOp::create(builder, alloc.
getLoc(), alloc).getResult();
910LogicalResult DeallocOp::inferReturnTypes(
911 MLIRContext *context, std::optional<::mlir::Location> location,
914 DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
916 IntegerType::get(context, 1));
920LogicalResult DeallocOp::verify() {
921 if (getMemrefs().size() != getConditions().size())
923 "must have the same number of conditions as memrefs to deallocate");
924 if (getRetained().size() != getUpdatedConditions().size())
925 return emitOpError(
"must have the same number of updated conditions "
926 "(results) as retained operands");
934 if (deallocOp.getMemrefs() == memrefs &&
935 deallocOp.getConditions() == conditions)
939 deallocOp.getMemrefsMutable().assign(memrefs);
940 deallocOp.getConditionsMutable().assign(conditions);
960struct DeallocRemoveDuplicateDeallocMemrefs
962 using OpRewritePattern<DeallocOp>::OpRewritePattern;
964 LogicalResult matchAndRewrite(DeallocOp deallocOp,
965 PatternRewriter &rewriter)
const override {
968 SmallVector<Value> newMemrefs, newConditions;
969 for (
auto [i, memref, cond] :
970 llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
971 if (memrefToCondition.count(memref)) {
974 Value &newCond = newConditions[memrefToCondition[memref]];
977 arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond);
979 memrefToCondition.insert({memref, newConditions.size()});
980 newMemrefs.push_back(memref);
981 newConditions.push_back(cond);
1002struct DeallocRemoveDuplicateRetainedMemrefs
1004 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1006 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1007 PatternRewriter &rewriter)
const override {
1010 SmallVector<Value> newRetained;
1011 SmallVector<unsigned> resultReplacementIdx;
1013 for (
auto retained : deallocOp.getRetained()) {
1014 if (seen.count(retained)) {
1015 resultReplacementIdx.push_back(seen[retained]);
1020 newRetained.push_back(retained);
1021 resultReplacementIdx.push_back(i++);
1026 if (newRetained.size() == deallocOp.getRetained().size())
1032 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
1033 deallocOp.getConditions(), newRetained);
1034 SmallVector<Value> replacements(
1035 llvm::map_range(resultReplacementIdx, [&](
unsigned idx) {
1036 return newDeallocOp.getUpdatedConditions()[idx];
1038 rewriter.
replaceOp(deallocOp, replacements);
1049 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1051 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1052 PatternRewriter &rewriter)
const override {
1053 if (deallocOp.getMemrefs().empty()) {
1054 Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(),
1057 deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1078 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1080 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1081 PatternRewriter &rewriter)
const override {
1082 SmallVector<Value> newMemrefs, newConditions;
1083 for (
auto [memref, cond] :
1084 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1086 newMemrefs.push_back(memref);
1087 newConditions.push_back(cond);
1115 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1117 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1118 PatternRewriter &rewriter)
const override {
1119 SmallVector<Value> newMemrefs(
1120 llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1121 auto extractStridedOp =
1122 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1123 if (!extractStridedOp)
1125 Value allocMemref = extractStridedOp.getOperand();
1126 auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1129 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1135 deallocOp.getConditions(), rewriter);
1153struct RemoveAllocDeallocPairWhenNoOtherUsers
1155 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1157 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1158 PatternRewriter &rewriter)
const override {
1159 SmallVector<Value> newMemrefs, newConditions;
1160 SmallVector<Operation *> toDelete;
1161 for (
auto [memref, cond] :
1162 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1163 if (
auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1167 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1169 memref.hasOneUse()) {
1170 toDelete.push_back(allocOp);
1175 newMemrefs.push_back(memref);
1176 newConditions.push_back(cond);
1183 for (Operation *op : toDelete)
1199 patterns.
add<DeallocRemoveDuplicateDeallocMemrefs,
1200 DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1201 EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1202 RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1209#define GET_OP_CLASSES
1210#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
true
Given two iterators into the same block, return "true" if a is before `b.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base class for generic analysis states.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseCustomTypeWithFallback(Type &result, function_ref< ParseResult(Type &result)> parseType)=0
Parse a custom type with the provided callback, unless the next token is #, in which case the generic...
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
void printStrippedAttrOrType(AttrOrType attrOrType)
Print the provided attribute in the context of an operation custom printer/parser: this will invoke d...
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
BoolAttr getBoolAttr(bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Block * getBlock()
Returns the operation block that contains this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the giv...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.