FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // A cast from a dynamic offset/stride to a static one is not guaranteed to
  // succeed at runtime, so treat such casts conservatively.
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // If the cast is guaranteed to succeed at runtime, emit a memref.cast.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = memref::CastOp::create(b, value.getLoc(), destType, value);
    return casted;
  }

  // Otherwise, allocate a new buffer of the destination type and copy into it.
  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = memref::DimOp::create(b, loc, value, i);
    dynamicOperands.push_back(size);
  }
  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
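// Illustrative note (hedged, not from the original source): for a value of
// type memref<?xf32, strided<[?], offset: ?>> and a destination type
// memref<?xf32>, the strides/offset would go from dynamic to static, so the
// helper above falls through to the alloc-and-copy path; a plain memref.cast
// is only emitted when it is guaranteed to succeed at runtime.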
LogicalResult mlir::bufferization::foldToBufferToTensorPair(
    RewriterBase &rewriter, ToBufferOp toBuffer,
    const BufferizationOptions &options) {
  auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
  if (!bufferToTensor)
    return failure();

  Type srcType = bufferToTensor.getBuffer().getType();
  Type destType = toBuffer.getType();

  // Directly rewrite body.
  if (srcType == destType) {
    rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
    if (failed(replacement))
      return failure();
    rewriter.replaceOp(toBuffer, *replacement);
    return success();
  }

  // Unranked memref -> ranked memref cast: May require a copy; not supported
  // at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast, or
  // ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
                                              bufferToTensor.getBuffer());
  return success();
}
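// Rough sketch of the fold (hypothetical IR, for illustration only):
//   %t = bufferization.to_tensor %m : memref<?xf32> to tensor<?xf32>
//   %b = bufferization.to_buffer %t : tensor<?xf32> to memref<?xf32>
// folds to a direct use of %m; when the memref types differ, a memref.cast
// (or, for ranked types, an alloc + copy) is inserted instead.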
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i));
      }
    }
  }
}
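// Hedged usage sketch (local names invented for illustration): collect the
// dynamic extents of a shaped value before allocating a buffer of equal shape.
//   SmallVector<Value> dims;
//   populateDynamicDimSizes(b, loc, shapedVal, dims);
//   FailureOr<Value> alloc = options.createAlloc(b, loc, memrefTy, dims);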
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer =
        getBuffer(rewriter, getCopy(), options, state);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options, state);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}
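// Sketch of the bufferized form (hypothetical IR): an alloc_tensor with a
// `copy` operand lowers to roughly
//   %alloc = memref.alloc(%dims) : memref<?xf32>
//   memref.copy %copy_buffer, %alloc
// after which the tensor result is replaced by %alloc.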
bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}
bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}
bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}
AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}
FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType =
        bufferization::detail::asMemRefType(bufferization::getBufferType(
            getCopy(), options, state, invocationStack));
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return cast<BufferLikeType>(
      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}
void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(), /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}
namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were dynamic in the
/// original op but whose size was provided as a constant operand.
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
                                       newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
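// Example of this rewrite:
//   %c5 = arith.constant 5 : index
//   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
// becomes
//   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
// with a tensor.cast back to tensor<?x?xf32> inserted for existing uses.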
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace
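// Example (hypothetical IR): with
//   %0 = bufferization.alloc_tensor(%sz) : tensor<4x?xf32>
// a `tensor.dim %0, %c1` folds directly to %sz, since the dynamic extent is
// already available as an operand of the alloc_tensor.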
void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}
LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}
void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}
Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return tensor::DimOp::create(b, getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}
namespace {
/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aims to find the dealloc op for the canonical source, which otherwise
    // could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp())) {
      if (canonicalSource != iface.getViewDest()) {
        break;
      }
      canonicalSource = iface.getViewSource();
    }

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either the clone or its source has more than one dealloc.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check: there must be no other deallocation between the clone and
    // the redundant dealloc, as we might otherwise free an alias of the
    // source before uses of the clone.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for the dealloc.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
                                      cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};
} // namespace
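// Net effect (hedged sketch, hypothetical IR): for
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   ... // no interfering frees
//   memref.dealloc %1 : memref<?xf32>
// the clone is replaced by %0 (with a memref.cast if the types differ) and
// the now-redundant dealloc is erased.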
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options,
                                         BufferizationState &state) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
  if (failed(buffer))
    return failure();
  memref::DeallocOp::create(rewriter, getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}
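// E.g. (hypothetical IR), `bufferization.dealloc_tensor %t : tensor<?xf32>`
// becomes `memref.dealloc %buf : memref<?xf32>` once %t has been mapped to
// its buffer %buf, and the tensor-level op disappears.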
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}
bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read; the destination is always written in-place.
  return true;
}
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options,
                                      BufferizationState &state) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, getDest(), options, state);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}
bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // Elements are copied 1:1 from the source buffer to the dest buffer.
  return true;
}
LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}
Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return ToTensorOp::create(
      builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()),
      getDest(),
      /*restrict=*/true, getWritable());
}
bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}
OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}
bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}
LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify whether dynamic dimensions match at compile time.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}
void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}
bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}
void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
        toBuffer->getNextNode() == this->getOperation())
      return toBuffer.getTensor();
  return {};
}
namespace {
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
    return success();
  }
};
} // namespace
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}
OpFoldResult ToBufferOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getBuffer().getType() == getType())
      return memrefToTensor.getBuffer();
  return {};
}
namespace {
/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto currentOutputMemRefType =
        dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
    if (!currentOutputMemRefType)
      return failure();

    auto memrefType = currentOutputMemRefType.cloneWith(
        srcTensorType.getShape(), srcTensorType.getElementType());
    Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
                                      tensorCastOperand.getOperand(),
                                      toBuffer.getReadOnly());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                memref);
    return success();
  }
};
/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
/// cast if necessary.
struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    // ... (options setup elided in the extraction) ...
    return foldToBufferToTensorPair(rewriter, toBuffer, options);
  }
};
/// Fold a load on a to_buffer operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
    if (!toBuffer)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
                                                   load.getIndices());
    return success();
  }
};
/// Fold dim of a to_buffer into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};
} // namespace
void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(context);
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToBufferToTensorPair(rewriter, *this, options);
  // Note: The return value of the op is the memref operand itself, so the op
  // is fully handled at this point.
  return success();
}
std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
      .getOperation();
}
std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return CloneOp::create(builder, alloc.getLoc(), alloc).getResult();
}
LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}
LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}
namespace {
/// Remove duplicate values in the list of memrefs to be deallocated. The
/// corresponding conditions are combined with `arith.ori`.
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the conditions don't match, the dealloc must happen on the union
        // of the two cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
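// Example of the merge (hypothetical IR):
//   bufferization.dealloc (%m, %m : memref<f32>, memref<f32>) if (%c0, %c1)
// becomes
//   %cond = arith.ori %c0, %c1 : i1
//   bufferization.dealloc (%m : memref<f32>) if (%cond)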
/// Remove duplicate values in the list of retained memrefs.
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything, so that we don't run into
    // an infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of retained operands.
    auto newDeallocOp =
        DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
                          deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};
/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Replace the results (updated conditions) with a
/// constant `false`.
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(),
                                                   rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};
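// E.g. (hypothetical IR), a dealloc with no memref operands such as
//   %0 = bufferization.dealloc retain (%r : memref<f32>)
// frees nothing, so each result is replaced by `arith.constant false`.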
/// Remove memrefs from the dealloc list whose associated condition is always
/// `false`.
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
/// A `memref.extract_strided_metadata` op is often inserted to get the base
/// memref when the operand is not already known to be the result of an
/// allocation. This pattern skips the extraction when the operand is directly
/// produced by an allocating op.
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};
/// Remove pairs of `bufferization.dealloc` and alloc operations when there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed.
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that this is indeed an allocate effect, that the op has no
        // other side effect (which would disallow removal), and that there
        // are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}
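// Hedged usage sketch (driver call chosen for illustration): these patterns
// are intended for a greedy rewrite driver, e.g.
//   RewritePatternSet patterns(ctx);
//   bufferization::populateDeallocOpCanonicalizationPatterns(patterns, ctx);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();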
#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
true
Given two iterators into the same block, return "true" if a is before `b.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base class for generic analysis states.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseCustomTypeWithFallback(Type &result, function_ref< ParseResult(Type &result)> parseType)=0
Parse a custom type with the provided callback, unless the next token is #, in which case the generic...
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
void printStrippedAttrOrType(AttrOrType attrOrType)
Print the provided attribute in the context of an operation custom printer/parser: this will invoke d...
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
BoolAttr getBoolAttr(bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Block * getBlock()
Returns the operation block that contains this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the giv...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.