FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // A cast that turns a dynamic offset or stride into a static one cannot be
  // proven correct at compile time, so such cases must be reallocated.
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // If the cast is guaranteed to succeed at runtime, insert a memref.cast.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  // Otherwise, allocate a new buffer of the destination type and copy.
  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }
  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
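// Illustration (a hedged sketch, not part of the source file; IR syntax
// approximate): casting *toward* a more dynamic layout is always safe and
// lowers to a single memref.cast, e.g.
//
//   %0 = memref.cast %src
//        : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
//
// The reverse direction, memref<?xf32, strided<[?], offset: ?>> ->
// memref<?xf32>, would turn a dynamic offset and stride into static ones,
// which cannot be verified statically; the helper then allocates a fresh
// buffer and fills it via options.createMemCpy (a memref.copy by default).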
LogicalResult mlir::bufferization::foldToBufferToTensorPair(
    RewriterBase &rewriter, ToBufferOp toBuffer,
    const BufferizationOptions &options) {
  auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
  if (!bufferToTensor)
    return failure();

  Type srcType = bufferToTensor.getBuffer().getType();
  Type destType = toBuffer.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> ranked memref cast: may require a realloc + copy.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
    if (failed(replacement))
      return failure();
    rewriter.replaceOp(toBuffer, *replacement);
    return success();
  }

  // Unranked memref -> ranked memref cast: not supported here.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Casts to an unranked memref never need a copy.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
                                              bufferToTensor.getBuffer());
  return success();
}
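// Illustration (hedged; IR syntax approximate for recent MLIR): a
// to_buffer(to_tensor(x)) round trip folds away entirely when the types
// match, e.g.
//
//   %t  = bufferization.to_tensor %m : memref<4xf32> to tensor<4xf32>
//   %m2 = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
//
// is replaced by %m. If only the layout differs, a memref.cast (or, in the
// worst case, a realloc + copy) is inserted instead.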
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}
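// Illustration (hedged sketch): for %v of type memref<?x8x?xf32>, this
// helper appends two values to `dynamicDims`, roughly:
//
//   %d0 = memref.dim %v, %c0 : memref<?x8x?xf32>
//   %d2 = memref.dim %v, %c2 : memref<?x8x?xf32>
//
// The static dimension (8) is skipped; tensor-typed values get tensor.dim
// ops instead.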
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  Location loc = getLoc();
  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }
  // Get the buffer of the `copy` operand, if present.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer =
        getBuffer(rewriter, getCopy(), options, state);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }
  // Compute the resulting memref type and create the allocation.
  auto allocType = bufferization::getBufferType(getResult(), options, state);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();
  // Copy the source buffer into the allocation, if requested.
  if (getCopy() &&
      failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}
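// Illustration (hedged; IR syntax approximate):
//
//   %0 = bufferization.alloc_tensor(%d) : tensor<?xf32>
//
// bufferizes to a bare allocation such as
//
//   %alloc = memref.alloc(%d) : memref<?xf32>
//
// while the `copy` form additionally emits a copy from the bufferized copy
// operand into %alloc.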
bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}
FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute the memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType = asMemRefType(bufferization::getBufferType(
        getCopy(), options, state, invocationStack));
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else {
    // Elided: fall back to the default memory space, if configured; otherwise:
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}
LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}
void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(), /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}
LogicalResult matchAndRewrite(AllocTensorOp op,
                              PatternRewriter &rewriter) const override {
  SmallVector<Value> newDynamicSizes;
  auto newShape = llvm::to_vector(op.getType().getShape());
  unsigned int dynValCounter = 0;
  // Replace each dynamic size by a static size, where known.
  for (int64_t i = 0; i < op.getType().getRank(); ++i) {
    if (!op.isDynamicDim(i))
      continue;
    Value value = op.getDynamicSizes()[dynValCounter++];
    APInt intVal;
    if (matchPattern(value, m_ConstantInt(&intVal))) {
      int64_t dim = intVal.getSExtValue();
      if (dim >= 0)
        newShape[i] = dim;
      else
        newDynamicSizes.push_back(value);
    } else {
      newDynamicSizes.push_back(value);
    }
  }
  RankedTensorType newType = RankedTensorType::get(
      newShape, op.getType().getElementType(), op.getType().getEncoding());
  if (newType == op.getType())
    return failure();
  auto newOp = rewriter.create<AllocTensorOp>(
      op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
  return success();
}
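// Illustration (hedged; IR syntax approximate): a constant dynamic extent
// is folded into the result type, with a tensor.cast restoring the original
// type for existing users:
//
//   %c16 = arith.constant 16 : index
//   %0 = bufferization.alloc_tensor(%c16) : tensor<?xf32>
//
// becomes
//
//   %0 = bufferization.alloc_tensor() : tensor<16xf32>
//   %1 = tensor.cast %0 : tensor<16xf32> to tensor<?xf32>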
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
  auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
  if (!allocTensorOp || !maybeConstantIndex)
    return failure();
  if (*maybeConstantIndex < 0 ||
      *maybeConstantIndex >= allocTensorOp.getType().getRank())
    return failure();
  if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
    return failure();
  rewriter.replaceOp(
      dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
  return success();
}
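// Illustration (hedged): tensor.dim of an alloc_tensor with a dynamic
// dimension folds to the corresponding size operand, e.g.
//
//   %0 = bufferization.alloc_tensor(%d) : tensor<?xf32>
//   %1 = tensor.dim %0, %c0 : tensor<?xf32>
//
// folds %1 to %d.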
void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  // ... (dynamic size operand list and optional keywords parsed above)
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  // ... (attribute dictionary, result type parsing, and operand resolution)
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}
void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}
Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Walk up view-like ops to find the canonical source buffer.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either the clone or the source has multiple deallocations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check: no operation between the clone and the redundant dealloc
    // may free the buffer.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for the dealloc.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      // Elided: bail out if `pos` has a MemoryEffects::Free effect.
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
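// Illustration (hedged): a clone whose source is deallocated in the same
// block can be bypassed, e.g.
//
//   %1 = bufferization.clone %0 : memref<4xf32> to memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
//   "use"(%1) : (memref<4xf32>) -> ()
//
// simplifies to using %0 directly, with the redundant dealloc removed
// (provided no operation in between may free the buffer).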
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options,
                                         BufferizationState &state) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}
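// Illustration (hedged; IR syntax approximate):
//
//   bufferization.dealloc_tensor %t : tensor<4xf32>
//
// bufferizes to a plain deallocation of %t's buffer:
//
//   memref.dealloc %m : memref<4xf32>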
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read; the destination is written and must bufferize
  // in-place.
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options,
                                      BufferizationState &state) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, getDest(), options, state);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}
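// Illustration (hedged; IR syntax approximate): with a tensor destination,
//
//   %r = bufferization.materialize_in_destination %src in %dest
//        : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//
// bufferizes to a copy from %src's buffer into %dest's buffer (by default a
// memref.copy), and %r is replaced by the destination buffer.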
bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(
        1, SmallVector<OpFoldResult>(
               cast<ShapedType>(getDest().getType()).getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // For memref destinations, wrap the buffer in a to_tensor op; this is only
  // legal because the op is required to carry the `restrict` flag.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(
      loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(),
      /*restrict=*/true, getWritable());
}
bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0);
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}
LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");

  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension sizes; assume they match at runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}
void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
    // Approximate alias analysis by conservatively folding only when the
    // to_buffer op immediately precedes this op in the same block.
    if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
        toBuffer->getNextNode() == this->getOperation())
      return toBuffer.getTensor();
  return {};
}
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();
    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
    return success();
  }
};

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}
OpFoldResult ToBufferOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getBuffer().getType() == getType())
      return memrefToTensor.getBuffer();
  return {};
}
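// Note the asymmetry between the two folds above: to_tensor(to_buffer(x))
// folds only when the to_buffer op immediately precedes the to_tensor op in
// the same block (a conservative aliasing check), while
// to_buffer(to_tensor(x)) folds purely on type equality, e.g. (hedged;
// syntax approximate):
//
//   %t  = bufferization.to_tensor %m : memref<8xf32> to tensor<8xf32>
//   %m2 = bufferization.to_buffer %t : tensor<8xf32> to memref<8xf32>
//
// folds %m2 to %m.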
struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const override {
    auto tensorCastOperand =
        toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                memref);
    return success();
  }
};
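// Illustration (hedged; IR syntax approximate): the tensor.cast is pushed
// through the to_buffer and becomes a memref.cast:
//
//   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
//   %m = bufferization.to_buffer %0 : tensor<?xf32> to memref<?xf32>
//
// becomes
//
//   %m0 = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
//   %m  = memref.cast %m0 : memref<4xf32> to memref<?xf32>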
struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const override {
    // Elided: fold to_buffer(to_tensor(x)) with default options.
    BufferizationOptions options;
    return foldToBufferToTensorPair(rewriter, toBuffer, options);
  }
};

struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
    if (!toBuffer)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(context);
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x; insert a memref.cast if necessary.
  // The return value indicates whether an error occurred, not whether the
  // fold applied.
  (void)foldToBufferToTensorPair(rewriter, *this, options);
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}
LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  // One i1 "updated condition" result per retained memref.
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique the memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // The memref is deallocated multiple times; the dealloc must happen
        // on the union (OR) of all its conditions.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond = rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond,
                                                  cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
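// Illustration (hedged; IR syntax approximate): deallocating the same
// memref twice under different conditions
//
//   bufferization.dealloc (%m, %m : memref<2xf32>, memref<2xf32>)
//                         if (%c0, %c1)
//
// is canonicalized to a single entry guarded by the OR of the conditions:
//
//   %c = arith.ori %c0, %c1 : i1
//   bufferization.dealloc (%m : memref<2xf32>) if (%c)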
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique the retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }
      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Nothing to do if there are no duplicates.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // A new op is needed because the number of results must match the number
    // of retained operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp,
          SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                             constFalse));
      return success();
    }
    return failure();
  }
};

struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));
    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};
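// Illustration (hedged; IR syntax approximate): a dealloc operating on the
// base memref returned by memref.extract_strided_metadata
//
//   %base, %off, %sz, %str = memref.extract_strided_metadata %alloc
//       : memref<?xf32> -> memref<f32>, index, index, index
//   bufferization.dealloc (%base : memref<f32>) if (%cond)
//
// is rewritten to deallocate the underlying allocation %alloc directly.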
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Drop the pair if the allocation has no other side effects and no
        // users other than this dealloc.
        if (hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }
      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }
    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();
    for (Operation *op : toDelete)
      rewriter.eraseOp(op);
    return success();
  }
};
void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"