FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // A memref.cast is guaranteed to succeed at runtime only if no offset or
  // stride changes from dynamic to static.
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // If the cast is guaranteed to succeed at runtime, emit a plain memref.cast.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  // Otherwise, allocate a new buffer of the destination type and copy into it.
  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }
  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
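// Illustration (added sketch, not from the source file): casting
// memref<?xf32, strided<[?], offset: ?>> to memref<5xf32> passes
// `areCastCompatible`, but it turns a dynamic offset/stride into a static one
// and could therefore fail at runtime. In that case the helper above emits a
// fresh allocation plus a copy instead of a memref.cast.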
LogicalResult mlir::bufferization::foldToBufferToTensorPair(
    RewriterBase &rewriter, ToBufferOp toBuffer,
    const BufferizationOptions &options) {
  auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
  if (!bufferToTensor)
    return failure();

  Type srcType = bufferToTensor.getMemref().getType();
  Type destType = toBuffer.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toBuffer, bufferToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, bufferToTensor.getMemref(), rankedDestType, options);
    if (failed(replacement))
      return failure();
    rewriter.replaceOp(toBuffer, *replacement);
    return success();
  }

  // Unranked memref -> ranked memref cast: not supported.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Ranked or unranked memref -> unranked memref cast: no copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
                                              bufferToTensor.getMemref());
  return success();
}
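// Illustration (added sketch, exact asm syntax may differ): the pair
//   %t = bufferization.to_tensor %m : memref<4xf32> ...
//   %b = bufferization.to_buffer %t : ... memref<4xf32>
// folds so that all uses of %b become %m directly; if only the layouts
// differ, a memref.cast (or, if the cast may fail at runtime, an alloc plus
// copy) bridges the two memref types.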
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}
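// Usage sketch (added; `buf` is a hypothetical value of type
// memref<?x4x?xf32>): collect the dynamic extents before creating an
// allocation of the same size:
//
//   SmallVector<Value> dynamicDims;
//   populateDynamicDimSizes(b, loc, buf, dynamicDims);
//   // dynamicDims now holds the results of memref.dim %buf, 0 and
//   // memref.dim %buf, 2.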
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }
  // Get the buffer of the `copy` operand, if present.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer =
        getBuffer(rewriter, getCopy(), options, state);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }
  // Compute the buffer type and create the allocation.
  auto allocType = bufferization::getBufferType(getResult(), options, state);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, getLoc(), copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, getLoc(), llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();
  // Copy the `copy` buffer into the new allocation, if requested.
  if (getCopy() &&
      failed(options.createMemCpy(rewriter, getLoc(), copyBuffer, *alloc)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}
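// Illustration (added sketch, exact asm syntax may differ): with a `copy`
// operand,
//   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>
// bufferizes roughly to an allocation whose dynamic sizes are taken from the
// copied buffer, followed by a copy:
//   %d = memref.dim %t_buffer, %c0
//   %alloc = memref.alloc(%d) : memref<?xf32>
//   memref.copy %t_buffer, %alloc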
bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}
FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute the memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType = bufferization::getBufferType(getCopy(), options,
                                                       state, invocationStack);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}
LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}
void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*sizeHint=*/Value(), /*memorySpace=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*sizeHint=*/Value(),
        /*memorySpace=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*sizeHint=*/Value(),
        memorySpace);
}
// ReplaceStaticShapeDims: replace dynamic sizes that are defined by constants
// with static entries in the result type.
LogicalResult matchAndRewrite(AllocTensorOp op,
                              PatternRewriter &rewriter) const override {
  SmallVector<Value> newDynamicSizes;
  SmallVector<int64_t> newShape(op.getType().getShape());
  unsigned int dynValCounter = 0;
  for (int64_t i = 0; i < op.getType().getRank(); ++i) {
    if (!op.isDynamicDim(i))
      continue;
    Value value = op.getDynamicSizes()[dynValCounter++];
    APInt intVal;
    if (matchPattern(value, m_ConstantInt(&intVal))) {
      int64_t dim = intVal.getSExtValue();
      if (dim >= 0)
        newShape[i] = intVal.getSExtValue();
      else
        newDynamicSizes.push_back(value);
    } else {
      newDynamicSizes.push_back(value);
    }
  }
  RankedTensorType newType = RankedTensorType::get(
      newShape, op.getType().getElementType(), op.getType().getEncoding());
  if (newType == op.getType())
    return failure();
  auto newOp = rewriter.create<AllocTensorOp>(
      op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
  return success();
}
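// Illustration (added sketch, exact asm syntax may differ): a dynamic size
// that is a constant is folded into the result type, e.g.
//   %c10 = arith.constant 10 : index
//   %0 = bufferization.alloc_tensor(%c10) : tensor<?xf32>
// becomes a static allocation, followed by a cast back to the old type:
//   %1 = bufferization.alloc_tensor() : tensor<10xf32>
//   %0 = tensor.cast %1 : tensor<10xf32> to tensor<?xf32>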
// FoldDimOfAllocTensorOp: fold tensor.dim of a dynamic dimension of an
// alloc_tensor to the corresponding size operand.
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
  auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
  if (!allocTensorOp || !maybeConstantIndex)
    return failure();
  if (*maybeConstantIndex < 0 ||
      *maybeConstantIndex >= allocTensorOp.getType().getRank())
    return failure();
  if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
    return failure();
  rewriter.replaceOp(
      dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
  return success();
}

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}
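// Illustration (added sketch): tensor.dim of a dynamic dimension of an
// alloc_tensor folds to the corresponding size operand, e.g.
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
//   %d = tensor.dim %0, %c0 : tensor<?xf32>
// canonicalizes so that %d is replaced by %sz.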
LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
// In AllocTensorOp::parse: the optional `copy` and `size_hint` operands are
// parsed and resolved only when their keyword is present.
if (copyKeyword.succeeded())
  if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
      parser.parseRParen())
    return failure();
if (sizeHintKeyword.succeeded())
  if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
    return failure();
// ... attr-dict and result type parsing, operand resolution ...
if (copyKeyword.succeeded())
  if (parser.resolveOperand(copyOperand, type, result.operands))
    return failure();
if (sizeHintKeyword.succeeded())
  if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
    return failure();
result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                    parser.getBuilder().getDenseI32ArrayAttr(
                        {static_cast<int32_t>(dynamicSizesOperands.size()),
                         static_cast<int32_t>(copyKeyword.succeeded()),
                         static_cast<int32_t>(sizeHintKeyword.succeeded())}));

// In AllocTensorOp::print: the optional operands, elided segment sizes, and
// the result type.
if (getCopy())
  p << " copy(" << getCopy() << ")";
if (getSizeHint())
  p << " size_hint=" << getSizeHint();
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                            AllocTensorOp::getOperandSegmentSizeAttr()});
auto type = getResult().getType();
if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
  p.printStrippedAttrOrType(validType);
else
  p << type;
Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}
// SimplifyClones: replace a clone with its source when the clone or the
// source is deallocated in the same block and nothing in between may free
// the memory.
LogicalResult matchAndRewrite(CloneOp cloneOp,
                              PatternRewriter &rewriter) const override {
  if (cloneOp.use_empty()) {
    rewriter.eraseOp(cloneOp);
    return success();
  }

  Value source = cloneOp.getInput();
  if (source.getType() != cloneOp.getType() &&
      !memref::CastOp::areCastCompatible({source.getType()},
                                         {cloneOp.getType()}))
    return failure();

  // Walk through view-like ops to find the canonical source value.
  Value canonicalSource = source;
  while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
             canonicalSource.getDefiningOp()))
    canonicalSource = iface.getViewSource();

  // Skip if either value has more than one dealloc operation.
  std::optional<Operation *> maybeCloneDeallocOp =
      memref::findDealloc(cloneOp.getOutput());
  if (!maybeCloneDeallocOp.has_value())
    return failure();
  std::optional<Operation *> maybeSourceDeallocOp =
      memref::findDealloc(canonicalSource);
  if (!maybeSourceDeallocOp.has_value())
    return failure();
  Operation *cloneDeallocOp = *maybeCloneDeallocOp;
  Operation *sourceDeallocOp = *maybeSourceDeallocOp;

  // If both are deallocated in the same block, we cannot decide which
  // dealloc is redundant.
  if (cloneDeallocOp && sourceDeallocOp &&
      cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
    return failure();

  Block *currentBlock = cloneOp->getBlock();
  Operation *redundantDealloc = nullptr;
  if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
    redundantDealloc = cloneDeallocOp;
  } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
    redundantDealloc = sourceDeallocOp;
  }
  if (!redundantDealloc)
    return failure();

  // Safety check: no operation between the clone and the redundant dealloc
  // may free the memory.
  for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
       pos = pos->getNextNode()) {
    // Bail if we run out of operations while looking for the dealloc.
    if (!pos)
      return failure();
    auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
    if (!effectInterface)
      continue;
    if (effectInterface.hasEffect<MemoryEffects::Free>())
      return failure();
  }

  if (source.getType() != cloneOp.getType())
    source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                             cloneOp.getType(), source);
  rewriter.replaceOp(cloneOp, source);
  rewriter.eraseOp(redundantDealloc);
  return success();
}

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
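// Illustration (added sketch, exact asm syntax may differ): a clone whose
// source is deallocated in the same block right after it is redundant:
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
// simplifies to using %0 directly, with that dealloc removed, provided no
// operation in between may free %0.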
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options,
                                         BufferizationState &state) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The destination is written in-place; the source is only read.
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {AliasingValue(getResult(), BufferRelation::Equivalent)};
  }
  return {};
}
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options,
                                      BufferizationState &state) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, getDest(), options, state);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}
bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // Elements are copied from the source buffer to the destination buffer.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}
Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The destination is a memref: wrap it in a to_tensor op that takes over
  // the 'restrict' guarantee of this op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    /*writable=*/true);
}
bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0);
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}
LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension sizes at compile time.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}
void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

// ToTensorOp
bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toBuffer = getMemref().getDefiningOp<ToBufferOp>())
    // Approximate alias analysis by conservatively folding only when the
    // to_buffer op is the immediately preceding operation in the same block.
    if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
        toBuffer->getNextNode() == this->getOperation())
      return toBuffer.getTensor();
  return {};
}
// DimOfToTensorFolder: fold tensor.dim(to_tensor(m)) to memref.dim(m).
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
  if (!memrefToTensorOp)
    return failure();
  rewriter.replaceOpWithNewOp<memref::DimOp>(
      dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
  return success();
}

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}
OpFoldResult ToBufferOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}
// ToBufferOfCast: fold to_buffer(tensor.cast(x)) into to_buffer(x) followed
// by a memref.cast.
LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                              PatternRewriter &rewriter) const override {
  auto tensorCastOperand =
      toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
  if (!tensorCastOperand)
    return failure();
  auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
      tensorCastOperand.getOperand().getType());
  if (!srcTensorType)
    return failure();
  auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                    srcTensorType.getElementType());
  Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
                                             tensorCastOperand.getOperand());
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                              memref);
  return success();
}
// ToBufferToTensorFolding: fold to_buffer(to_tensor(x)).
LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                              PatternRewriter &rewriter) const override {
  // ... (delegates to foldToBufferToTensorPair)
}

// LoadOfToBuffer: rewrite memref.load(to_buffer(t)) to tensor.extract(t).
LogicalResult matchAndRewrite(memref::LoadOp load,
                              PatternRewriter &rewriter) const override {
  auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
  if (!toBuffer)
    return failure();
  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
                                                 load.getIndices());
  return success();
}

// DimOfCastOp: rewrite memref.dim(to_buffer(t)) to tensor.dim(t).
LogicalResult matchAndRewrite(memref::DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
  if (!castOp)
    return failure();
  Value newSource = castOp.getOperand();
  rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                             dimOp.getIndex());
  return success();
}

void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(context);
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x, inserting a memref.cast if necessary.
  (void)foldToBufferToTensorPair(rewriter, *this, options);
  return success();
}
std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}
LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}
LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();
  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique the memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] : llvm::enumerate(
             deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the conditions don't match, take the disjunction.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond = rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond,
                                                  cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
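// Illustration (added sketch, exact asm syntax may differ): duplicate
// memrefs in the dealloc list are merged and their conditions OR-ed
// together, e.g.
//   bufferization.dealloc (%m, %m : ...) if (%c1, %c2)
// becomes
//   %cond = arith.ori %c1, %c2 : i1
//   bufferization.dealloc (%m : ...) if (%cond)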
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique the retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }
      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }
    // Nothing to do if there are no duplicates.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();
    // The number of results changes, so a new op must be created.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};
// EraseEmptyDealloc: a dealloc without memrefs updates no conditions; all
// results are replaced by constant false.
LogicalResult matchAndRewrite(DeallocOp deallocOp,
                              PatternRewriter &rewriter) const override {
  if (deallocOp.getMemrefs().empty()) {
    Value constFalse = rewriter.create<arith::ConstantOp>(
        deallocOp.getLoc(), rewriter.getBoolAttr(false));
    rewriter.replaceOp(
        deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                      constFalse));
    return success();
  }
  return failure();
}
// EraseAlwaysFalseDealloc: drop (memref, condition) pairs whose condition is
// a constant false.
LogicalResult matchAndRewrite(DeallocOp deallocOp,
                              PatternRewriter &rewriter) const override {
  SmallVector<Value> newMemrefs, newConditions;
  for (auto [memref, cond] :
       llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
    if (!matchPattern(cond, m_Zero())) {
      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }
  }
  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                rewriter);
}
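// Illustration (added sketch, exact asm syntax may differ): entries whose
// condition is a constant false can never trigger a deallocation and are
// dropped, e.g.
//   %false = arith.constant false
//   bufferization.dealloc (%m0, %m1 : ...) if (%false, %c)
// shrinks to
//   bufferization.dealloc (%m1 : ...) if (%c)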
// SkipExtractMetadataOfAlloc: deallocate the allocated memref directly
// instead of the result of memref.extract_strided_metadata.
LogicalResult matchAndRewrite(DeallocOp deallocOp,
                              PatternRewriter &rewriter) const override {
  SmallVector<Value> newMemrefs(
      llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
        auto extractStridedOp =
            memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
        if (!extractStridedOp)
          return memref;
        Value allocMemref = extractStridedOp.getOperand();
        auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
        if (!allocOp)
          return memref;
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
          return allocMemref;
        return memref;
      }));
  return updateDeallocIfChanged(deallocOp, newMemrefs,
                                deallocOp.getConditions(), rewriter);
}
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Remove the alloc/dealloc pair if the alloc has no other users.
        if (hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }
      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }
    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();
    for (Operation *op : toDelete)
      rewriter.eraseOp(op);
    return success();
  }
};
void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"