#include "llvm/ADT/SmallVectorExtras.h"
/// Try to cast the given ranked MemRef-typed value to the given ranked MemRef
/// type. Insert a reallocation + copy if it cannot be statically guaranteed
/// that a direct cast would be valid.
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // A cast is guaranteed to succeed at runtime if no offset or stride goes
  // from dynamic in the source to static in the target.
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: "areCastCompatible" only means that the cast is valid; it may still
  // abort at runtime. Emit a direct cast only if it is guaranteed to succeed.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = memref::CastOp::create(b, value.getLoc(), destType, value);
    return casted;
  }

  // Otherwise, allocate a new buffer of the target type and copy into it.
  Location loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = memref::DimOp::create(b, loc, value, i);
    dynamicOperands.push_back(size);
  }
  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
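// Illustrative sketch (not part of the source): a cast from memref<?xf32> to
// memref<4xf32> is "cast compatible" but goes dynamic -> static, so it is not
// guaranteed to succeed at runtime. The helper above therefore falls back to
// alloc + copy. Schematic IR:
//
//   %alloc = memref.alloc() : memref<4xf32>
//   memref.copy %src, %alloc : memref<?xf32> to memref<4xf32>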
/// Try to fold to_buffer(to_tensor(x)). Insert a cast (or a realloc + copy)
/// if the buffer types differ.
LogicalResult mlir::bufferization::foldToBufferToTensorPair(
    RewriterBase &rewriter, ToBufferOp toBuffer,
    const BufferizationOptions &options) {
  auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
  if (!bufferToTensor)
    return failure();

  Type srcType = bufferToTensor.getBuffer().getType();
  Type destType = toBuffer.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
    if (failed(replacement))
      return failure();
    rewriter.replaceOp(toBuffer, *replacement);
    return success();
  }

  // Unranked memref -> ranked memref cast: May require a copy. Not supported.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Ranked/unranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
                                              bufferToTensor.getBuffer());
  return success();
}
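// For example (schematic IR, assuming the buffer types match exactly):
//
//   %t = bufferization.to_tensor %m : memref<4xf32> to tensor<4xf32>
//   %b = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
//
// folds so that all uses of %b become direct uses of %m.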
/// Populate `dynamicDims` with tensor::DimOp / memref::DimOp results for all
/// dynamic dimensions of the given shaped value.
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i));
      }
    }
  }
}
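// Hedged usage sketch (not from the source; `b`, `loc`, `source`, and a
// BufferizationOptions `options` are assumed to exist in the caller): gather
// the dynamic extents of `source` so that an allocation of the same shape can
// be created.
//
//   SmallVector<Value> dynamicDims;
//   populateDynamicDimSizes(b, loc, source, dynamicDims);
//   FailureOr<Value> alloc = options.createAlloc(
//       b, loc, llvm::cast<MemRefType>(source.getType()), dynamicDims);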
//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer =
        getBuffer(rewriter, getCopy(), options, state);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options, state);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy() &&
      failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
    return failure();

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}
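// Illustrative end-to-end effect (schematic IR, not from the source): an
// alloc_tensor with a `copy` operand bufferizes to an allocation of the same
// shape followed by a copy of the source buffer.
//
//   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>
//
// becomes, roughly:
//
//   %dim = memref.dim %t_buffer, %c0 : memref<?xf32>
//   %alloc = memref.alloc(%dim) : memref<?xf32>
//   memref.copy %t_buffer, %alloc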
bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}
FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType =
        bufferization::detail::asMemRefType(bufferization::getBufferType(
            getCopy(), options, state, invocationStack));
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return cast<BufferLikeType>(
      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
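// Inference order, restating the logic above: an explicit `memory_space`
// attribute wins; otherwise the space is taken from the `copy` operand's
// buffer type; otherwise the options' default memory space is used; if none
// of these applies, bufferization fails with "could not infer memory space".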
LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && failed(verifyDynamicDimensionCount(
                        getOperation(), getType(), getDynamicSizes())))
    return failure();
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}
void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(), /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}
namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were dynamic in the
/// original operation but whose size was defined by a constant.
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
                                       newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
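// Illustrative fold (schematic IR): a dynamic extent backed by a constant
// becomes static, with a tensor.cast restoring the original type.
//
//   %c10 = arith.constant 10 : index
//   %0 = bufferization.alloc_tensor(%c10) : tensor<?xf32>
//
// canonicalizes to:
//
//   %1 = bufferization.alloc_tensor() : tensor<10xf32>
//   %0 = tensor.cast %1 : tensor<10xf32> to tensor<?xf32>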
/// Fold dims that reference an AllocTensorOp.
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace
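// Illustrative fold (schematic IR): a tensor.dim of a dynamic alloc_tensor
// dimension is replaced by the corresponding size operand.
//
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
//   %d = tensor.dim %0, %c0 : tensor<?xf32>
//
// canonicalizes to a direct use of %sz.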
void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}
LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::map_to_vector<4>(
      llvm::seq<int64_t>(0, getType().getRank()),
      [&](int64_t dim) -> OpFoldResult {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.getIndexAttr(getStaticSize(dim));
      });
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}
void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}
Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return tensor::DimOp::create(b, getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}
//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

namespace {
/// Merge the clone and its source (if possible) into the clone's source.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aims to find the dealloc op for the canonical source, which otherwise
    // could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp())) {
      if (canonicalSource != iface.getViewDest()) {
        break;
      }
      canonicalSource = iface.getViewSource();
    }

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations between cloneOp and
    // redundantDealloc, as otherwise we might deallocate an alias of source
    // before the uses of the clone.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for a deallocation op.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
                                      cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};
} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
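// Illustrative rewrite (schematic IR, not from the source): when the clone
// and the redundant dealloc are in the same block and no other free may
// happen in between,
//
//   %1 = bufferization.clone %0 : memref<4xf32> to memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
//   "use"(%1) : (memref<4xf32>) -> ()
//
// simplifies so that the uses of %1 read %0 directly and one alloc/dealloc
// pair disappears.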
//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options,
                                         BufferizationState &state) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
  if (failed(buffer))
    return failure();
  memref::DeallocOp::create(rewriter, getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}
//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and always bufferizes in-place by default. The
  // destination is written and is forced to bufferize in-place (if it is a
  // tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options,
                                      BufferizationState &state) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, getDest(), options, state);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}
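// Illustrative effect (schematic IR, not from the source): with a tensor
// destination,
//
//   %r = bufferization.materialize_in_destination %src in %dest
//        : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//
// bufferizes to a plain copy between the two underlying buffers:
//
//   memref.copy %src_buffer, %dest_buffer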
bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // no element is read after it has already been written.
  return true;
}
LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(
        1, SmallVector<OpFoldResult>(
               llvm::cast<ShapedType>(getDest().getType()).getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}
Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op, so it may be set only once.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return ToTensorOp::create(
      builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()),
      getDest(), /*restrict=*/true, getWritable());
}
bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}
LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  ShapedType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("source/destination shapes are incompatible");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension sizes. Assume that they match at
        // runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}
void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
        toBuffer->getNextNode() == this->getOperation())
      return toBuffer.getTensor();
  return {};
}
namespace {
/// Canonicalize tensor.dim(bufferization.to_tensor) to memref.dim.
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();
    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}
//===----------------------------------------------------------------------===//
// ToBufferOp
//===----------------------------------------------------------------------===//

OpFoldResult ToBufferOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getBuffer().getType() == getType())
      return memrefToTensor.getBuffer();
  return {};
}
namespace {
/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto currentOutputMemRefType =
        dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
    if (!currentOutputMemRefType)
      return failure();

    auto memrefType = currentOutputMemRefType.cloneWith(
        srcTensorType.getShape(), srcTensorType.getElementType());
    Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
                                      tensorCastOperand.getOperand(),
                                      toBuffer.getReadOnly());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                memref);
    return success();
  }
};
/// Canonicalize bufferization.to_buffer + bufferization.to_tensor.
struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToBufferToTensorPair(rewriter, toBuffer, options);
  }
};
/// Fold a load on a to_buffer operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
    if (!toBuffer || !toBuffer.getReadOnly())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
                                                   load.getIndices());
    return success();
  }
};
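// Illustrative fold (schematic IR): a load from a read-only to_buffer is an
// element read of the original tensor, so
//
//   %m = bufferization.to_buffer %t read_only : tensor<4xf32> to memref<4xf32>
//   %v = memref.load %m[%i] : memref<4xf32>
//
// canonicalizes to:
//
//   %v = tensor.extract %t[%i] : tensor<4xf32>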
/// Fold memref.dim(to_buffer) into tensor.dim of the source tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};
} // namespace
void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(context);
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToBufferToTensorPair(rewriter, *this, options);
  // Note: The return value of `bufferize` indicates whether there was an
  // error or not. (And not whether the pattern matched or not.)
  return success();
}
std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return CloneOp::create(builder, alloc.getLoc(), alloc).getResult();
}
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}
namespace {
/// Remove duplicate values in the list of memrefs to be deallocated. The
/// corresponding conditions are OR-ed.
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // The memref already exists; OR the conditions.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
                                         newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
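// Illustrative rewrite (schematic IR, not from the source): duplicate memrefs
// are merged and their conditions OR-ed, so
//
//   bufferization.dealloc (%m, %m : memref<f32>, memref<f32>) if (%c0, %c1)
//
// becomes:
//
//   %or = arith.ori %c0, %c1 : i1
//   bufferization.dealloc (%m : memref<f32>) if (%or)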
/// Remove duplicate values in the list of retained memrefs. The corresponding
/// result condition values are replaced accordingly.
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of retained operands.
    auto newDeallocOp =
        DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
                          deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};
/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Replace the results with constant 'false'.
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = arith::ConstantOp::create(
          rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};
/// Remove memrefs from the deallocation list if their associated condition is
/// always 'false'.
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
/// A `memref.extract_strided_metadata` is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of an
/// allocation operation. This pattern skips the extraction if the operand is
/// produced by an allocation operation (e.g., after other canonicalizations
/// have simplified the IR).
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));
    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};
/// Remove pairs of `bufferization.dealloc` and alloc operations if there is no
/// other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicates first.
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no
        // other side effects (which would not allow us to remove the op), and
        // that there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};
} // namespace
void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"