//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type, memory space and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpace() != destType.getMemorySpace())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // A cast is guaranteed to succeed at runtime only if no offset or stride
  // goes from dynamic to static.
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Insert a plain cast if it is guaranteed to succeed at runtime.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  // Otherwise, allocate a buffer of the target type and copy the value over.
  Location loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }
  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
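// Illustrative IR sketch (not from this file; SSA names and types are made
// up): a dynamic-to-static offset/stride such as
//
//   memref<16xf32, strided<[?], offset: ?>>  ->  memref<16xf32>
//
// is cast compatible but not *guaranteed* compatible, so the fallback path
// above allocates a fresh buffer and copies into it (with the default
// options, via memref.alloc + memref.copy):
//
//   %0 = memref.alloc() : memref<16xf32>
//   memref.copy %src, %0
//       : memref<16xf32, strided<[?], offset: ?>> to memref<16xf32>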
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of
/// the to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref,
    const BufferizationOptions &options) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType, options);
    if (failed(replacement))
      return failure();
    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> ranked memref cast is not supported here.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Ranked/unranked memref -> unranked memref cast: always compatible.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}
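// For example (sketch, not from this file; SSA names made up), the pair
//
//   %t = bufferization.to_tensor %m : memref<4xf32>
//   %m2 = bufferization.to_memref %t : memref<4xf32>
//
// folds to a direct use of %m; when the two memref types differ, a
// memref.cast (or, through castOrReallocMemRefValue, an alloc plus copy) is
// inserted instead.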
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}
//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get the buffer for the `copy` operand, if present.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute the buffer type and create the allocation.
  auto allocType = bufferization::getBufferType(getResult(), options);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create the memory copy, if any.
  if (getCopy() &&
      failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
    return failure();

  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}
bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // The result is written to only if there is a `copy` operand.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}
FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value,
                             const BufferizationOptions &options /* ... */) {
  assert(value == getResult() && "invalid value");

  // Compute the memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType = bufferization::getBufferType(getCopy(), options);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else {
    // ... otherwise fall back to the default memory space from `options`;
    // if none is available:
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}
LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}
void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(), /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}
/// ReplaceStaticShapeDims: make the result type of an alloc_tensor
/// statically sized whenever a dynamic size operand is a constant.
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newDynamicSizes;
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
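// For example (sketch, not from this file; SSA names made up), with
// %c16 = arith.constant 16 : index, ReplaceStaticShapeDims rewrites
//
//   %0 = bufferization.alloc_tensor(%c16) : tensor<?xf32>
//
// into a statically shaped allocation plus a cast back to the old type:
//
//   %1 = bufferization.alloc_tensor() : tensor<16xf32>
//   %0 = tensor.cast %1 : tensor<16xf32> to tensor<?xf32>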
/// FoldDimOfAllocTensorOp: fold tensor.dim of a dynamic dimension of an
/// alloc_tensor to the corresponding dynamic size operand.
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
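// For example (sketch, not from this file; SSA names made up),
// FoldDimOfAllocTensorOp folds
//
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
//   %d = tensor.dim %0, %c0 : tensor<?xf32>
//
// into a direct use of %sz, the dynamic size the allocation was built with.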
void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}
LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  // ... parse `(` dynamic-size operands `)` and the optional `copy` /
  // `size_hint` keywords ...
  if (copyKeyword.succeeded()) {
    // ... parse `(` <copy operand> `)` ...
  }
  if (sizeHintKeyword.succeeded()) {
    // ... parse `=` <size-hint operand> ...
  }
  // ... parse the optional attr-dict, `:` and the result tensor type, and
  // resolve the dynamic size operands as index values ...
  if (copyKeyword.succeeded()) {
    // ... resolve the copy operand against the result type ...
  }
  if (sizeHintKeyword.succeeded()) {
    // ... resolve the size-hint operand ...
  }
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}
void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}
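// Putting parse() and print() together, the custom assembly format handled
// above looks roughly like this (illustrative sketch; SSA names made up):
//
//   %t = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %u = bufferization.alloc_tensor() copy(%src) : tensor<4xf32>
//
// with an optional trailing `size_hint=%n` clause before the attr-dict.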
Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}
//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

/// SimplifyClones: merge the clone and its source (by converting the clone to
/// a cast) when possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Look through view-like ops to find the canonical source.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either the clone or the source has more than one dealloc.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }
    if (!redundantDealloc)
      return failure();

    // Bail out if another deallocation happens between the clone and the
    // redundant dealloc; it might free an alias of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};
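// For example (sketch, not from this file; SSA names made up), when the clone
// is deallocated in its own block, SimplifyClones can rewrite
//
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   "use"(%1) : (memref<?xf32>) -> ()
//   memref.dealloc %1 : memref<?xf32>
//
// so that uses of %1 refer to %0 directly and the redundant dealloc is
// erased.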
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}
//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}
bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read; the destination is written to and must
  // bufferize in-place.
  return true;
}
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}
bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // Elements are copied from the source to the destination 1:1, so the
  // access is elementwise.
  return true;
}
LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}
Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The 'restrict' attribute is transferred from this op to the newly
  // created to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    /*writable=*/getWritable());
}
bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}
SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}
OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}
bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}
LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension sizes; assume they match at
        // runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}
void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}
bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}
MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}
void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there
    // is no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}
/// DimOfToTensorFolder: fold tensor.dim(to_tensor(m)) to memref.dim(m).
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();
    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}
//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}
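// Note that this fold only fires when the memref types match exactly, e.g.
// (sketch, not from this file; SSA names made up):
//
//   %t = bufferization.to_tensor %m : memref<8xf32>
//   %m2 = bufferization.to_memref %t : memref<8xf32>   // folds to %m
//
// Mismatched types are left to foldToMemrefToTensorPair and the
// canonicalization patterns below, which may insert a memref.cast.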
/// ToMemrefOfCast: replace to_memref(tensor.cast(x)) with
/// memref.cast(to_memref(x)).
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};
/// ToMemrefToTensorFolding: canonicalize to_memref(to_tensor(x)) to x.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    // ...
    return foldToMemrefToTensorPair(rewriter, toMemref, options);
  }
};
/// LoadOfToMemref: fold a memref.load on a to_memref into a tensor.extract
/// on the corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};
/// DimOfCastOp: fold memref.dim(to_memref(t)) to tensor.dim(t).
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};
void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this, options);
  // Note: The op is a no-op after bufferization if the fold did not apply.
  return success();
}
// AllocationOpInterface methods of CloneOp.
std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  // One i1 "updated condition" result per retained memref.
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}
LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}
/// Remove duplicate memrefs from the dealloc list, OR-ing their conditions.
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // The dealloc must happen if and only if at least one of the
        // conditions is true.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if nothing changed, to avoid an infinite loop of
    // pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
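// For example (sketch, not from this file; SSA names made up), deallocating
// the same memref twice under different conditions
//
//   bufferization.dealloc (%m, %m : memref<2xf32>, memref<2xf32>)
//                         if (%cond0, %cond1)
//
// is rewritten to a single entry guarded by the OR of both conditions:
//
//   %cond = arith.ori %cond0, %cond1 : i1
//   bufferization.dealloc (%m : memref<2xf32>) if (%cond)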
/// Remove duplicate values from the retained list, reusing the updated
/// condition of the first occurrence.
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }
      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if nothing changed, to avoid an infinite loop of
    // pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // A new op is needed because the number of results must match the number
    // of retained operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};
/// Erase dealloc ops with an empty memref list, replacing all updated
/// conditions with constant false.
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp,
          SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                             constFalse));
      return success();
    }
    return failure();
  }
};
/// Remove memrefs whose condition is a constant false from the dealloc list.
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
/// When a memref operand is produced by memref.extract_strided_metadata on
/// an allocation, deallocate the underlying allocation directly.
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));
    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};
/// Remove alloc/dealloc pairs when the allocation has no users other than
/// the dealloc.
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // The pair can go away if the defining op only allocates and the
        // memref has no other users.
        if (hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }
      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();
    for (Operation *op : toDelete)
      rewriter.eraseOp(op);
    return success();
  }
};
void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}
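// Hypothetical usage sketch (not part of this file; the helper name is made
// up): the patterns registered above can be driven to a fixpoint with the
// greedy pattern rewrite driver from
// "mlir/Transforms/GreedyPatternRewriteDriver.h".
static LogicalResult runDeallocCanonicalizationExample(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  bufferization::populateDeallocOpCanonicalizationPatterns(patterns,
                                                           op->getContext());
  // Apply the patterns repeatedly until no more rewrites fire.
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}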
#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"