29 auto srcType = llvm::cast<MemRefType>(value.
getType());
32 if (srcType.getElementType() != destType.getElementType())
34 if (srcType.getRank() != destType.getRank())
40 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
41 int64_t sourceOffset, targetOffset;
43 if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
44 failed(target.getStridesAndOffset(targetStrides, targetOffset)))
46 auto dynamicToStatic = [](int64_t a, int64_t b) {
47 return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
49 if (dynamicToStatic(sourceOffset, targetOffset))
51 for (
auto it : zip(sourceStrides, targetStrides))
52 if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
60 if (memref::CastOp::areCastCompatible(srcType, destType) &&
61 isGuaranteedCastCompatible(srcType, destType)) {
68 for (
int i = 0; i < destType.getRank(); ++i) {
69 if (destType.getShape()[i] != ShapedType::kDynamic)
71 Value size = b.
create<memref::DimOp>(loc, value, i);
72 dynamicOperands.push_back(size);
75 FailureOr<Value>
copy =
76 options.createAlloc(b, loc, destType, dynamicOperands);
79 if (failed(
options.createMemCpy(b, loc, value, *
copy)))
89 auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
93 Type srcType = memrefToTensor.getMemref().getType();
94 Type destType = toMemref.getType();
97 if (srcType == destType) {
98 rewriter.
replaceOp(toMemref, memrefToTensor.getMemref());
102 auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
103 auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
104 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
107 if (rankedSrcType && rankedDestType) {
109 rewriter, memrefToTensor.getMemref(), rankedDestType,
options);
110 if (failed(replacement))
113 rewriter.
replaceOp(toMemref, *replacement);
119 if (unrankedSrcType && rankedDestType)
124 assert(memref::CastOp::areCastCompatible(srcType, destType) &&
125 "expected that types are cast compatible");
127 memrefToTensor.getMemref());
134 auto shapedType = llvm::cast<ShapedType>(shapedValue.
getType());
135 for (int64_t i = 0; i < shapedType.getRank(); ++i) {
136 if (shapedType.isDynamicDim(i)) {
137 if (llvm::isa<MemRefType>(shapedType)) {
138 dynamicDims.push_back(b.
create<memref::DimOp>(loc, shapedValue, i));
140 assert(llvm::isa<RankedTensorType>(shapedType) &&
"expected tensor");
141 dynamicDims.push_back(b.
create<tensor::DimOp>(loc, shapedValue, i));
151 LogicalResult AllocTensorOp::bufferize(
RewriterBase &rewriter,
157 if (getOperation()->getUses().empty()) {
158 rewriter.
eraseOp(getOperation());
166 if (failed(maybeCopyBuffer))
168 copyBuffer = *maybeCopyBuffer;
173 if (failed(allocType))
177 assert(dynamicDims.empty() &&
"expected either `copy` or `dynamicDims`");
180 FailureOr<Value> alloc =
options.createAlloc(
181 rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
187 if (failed(
options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
197 bool AllocTensorOp::resultBufferizesToMemoryWrite(
OpResult opResult,
200 return static_cast<bool>(getCopy());
203 bool AllocTensorOp::bufferizesToMemoryRead(
OpOperand &opOperand,
206 "expected copy operand");
210 bool AllocTensorOp::bufferizesToMemoryWrite(
OpOperand &opOperand,
213 "expected copy operand");
223 FailureOr<BaseMemRefType>
226 assert(value == getResult() &&
"invalid value");
230 if (getMemorySpace().has_value()) {
231 memorySpace = *getMemorySpace();
232 }
else if (getCopy()) {
233 auto copyBufferType =
235 if (failed(copyBufferType))
237 memorySpace = copyBufferType->getMemorySpace();
241 return getOperation()->emitError(
"could not infer memory space");
249 return emitError(
"dynamic sizes not needed when copying a tensor");
252 <<
getType().getNumDynamicDims() <<
" dynamic sizes";
254 return emitError(
"expected that `copy` and return type match");
259 RankedTensorType type,
ValueRange dynamicSizes) {
260 build(builder, result, type, dynamicSizes,
Value(),
266 RankedTensorType type,
ValueRange dynamicSizes,
268 build(builder, result, type, dynamicSizes,
copy,
Value(),
274 IntegerAttr memorySpace) {
275 build(builder, result, type, dynamicSizes,
copy,
Value(),
294 LogicalResult matchAndRewrite(AllocTensorOp op,
300 unsigned int dynValCounter = 0;
301 for (int64_t i = 0; i < op.getType().getRank(); ++i) {
302 if (!op.isDynamicDim(i))
304 Value value = op.getDynamicSizes()[dynValCounter++];
307 int64_t dim = intVal.getSExtValue();
309 newShape[i] = intVal.getSExtValue();
311 newDynamicSizes.push_back(value);
313 newDynamicSizes.push_back(value);
317 newShape, op.getType().getElementType(), op.getType().getEncoding());
318 if (newType == op.getType())
320 auto newOp = rewriter.
create<AllocTensorOp>(
321 op.getLoc(), newType, newDynamicSizes,
Value());
330 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
332 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
333 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
334 if (!allocTensorOp || !maybeConstantIndex)
336 if (*maybeConstantIndex < 0 ||
337 *maybeConstantIndex >= allocTensorOp.getType().getRank())
339 if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
342 dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
350 results.
add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
355 auto shapes = llvm::to_vector<4>(
356 llvm::map_range(llvm::seq<int64_t>(0,
getType().getRank()),
358 if (isDynamicDim(dim))
359 return getDynamicSize(builder, dim);
362 reifiedReturnShapes.emplace_back(std::move(shapes));
373 if (copyKeyword.succeeded())
379 if (sizeHintKeyword.succeeded())
393 if (copyKeyword.succeeded())
396 if (sizeHintKeyword.succeeded())
399 result.
addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
401 {static_cast<int32_t>(dynamicSizesOperands.size()),
402 static_cast<int32_t>(copyKeyword.succeeded()),
403 static_cast<int32_t>(sizeHintKeyword.succeeded())}));
410 p <<
" copy(" << getCopy() <<
")";
412 p <<
" size_hint=" << getSizeHint();
414 AllocTensorOp::getOperandSegmentSizeAttr()});
416 auto type = getResult().getType();
417 if (
auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
424 assert(isDynamicDim(idx) &&
"expected dynamic dim");
426 return b.
create<tensor::DimOp>(getLoc(), getCopy(), idx);
427 return getOperand(getIndexOfDynamicSize(idx));
445 LogicalResult matchAndRewrite(CloneOp cloneOp,
447 if (cloneOp.use_empty()) {
452 Value source = cloneOp.getInput();
453 if (source.
getType() != cloneOp.getType() &&
454 !memref::CastOp::areCastCompatible({source.getType()},
455 {cloneOp.getType()}))
460 Value canonicalSource = source;
461 while (
auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
463 canonicalSource = iface.getViewSource();
465 std::optional<Operation *> maybeCloneDeallocOp =
468 if (!maybeCloneDeallocOp.has_value())
470 std::optional<Operation *> maybeSourceDeallocOp =
472 if (!maybeSourceDeallocOp.has_value())
474 Operation *cloneDeallocOp = *maybeCloneDeallocOp;
475 Operation *sourceDeallocOp = *maybeSourceDeallocOp;
479 if (cloneDeallocOp && sourceDeallocOp &&
483 Block *currentBlock = cloneOp->getBlock();
485 if (cloneDeallocOp && cloneDeallocOp->
getBlock() == currentBlock) {
486 redundantDealloc = cloneDeallocOp;
487 }
else if (sourceDeallocOp && sourceDeallocOp->
getBlock() == currentBlock) {
488 redundantDealloc = sourceDeallocOp;
491 if (!redundantDealloc)
499 for (
Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
500 pos = pos->getNextNode()) {
504 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
505 if (!effectInterface)
511 if (source.
getType() != cloneOp.getType())
512 source = rewriter.
create<memref::CastOp>(cloneOp.getLoc(),
513 cloneOp.getType(), source);
515 rewriter.
eraseOp(redundantDealloc);
524 results.
add<SimplifyClones>(context);
531 LogicalResult DeallocTensorOp::bufferize(
RewriterBase &rewriter,
536 rewriter.
create<memref::DeallocOp>(getLoc(), *buffer);
537 rewriter.
eraseOp(getOperation());
545 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
547 return opOperand == getSourceMutable();
550 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
552 if (opOperand == getDestMutable()) {
553 assert(isa<TensorType>(getDest().
getType()) &&
"expected tensor type");
559 bool MaterializeInDestinationOp::mustBufferizeInPlace(
568 MaterializeInDestinationOp::getAliasingValues(
OpOperand &opOperand,
570 if (opOperand == getDestMutable()) {
571 assert(isa<TensorType>(getDest().
getType()) &&
"expected tensor type");
578 MaterializeInDestinationOp::bufferize(
RewriterBase &rewriter,
580 bool tensorDest = isa<TensorType>(getDest().
getType());
584 if (failed(maybeBuffer))
586 buffer = *maybeBuffer;
588 assert(isa<BaseMemRefType>(getDest().
getType()) &&
"expected memref type");
592 if (failed(srcBuffer))
594 if (failed(
options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
601 bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
610 if (getOperation()->getNumResults() == 1) {
611 assert(isa<TensorType>(getDest().
getType()) &&
"expected tensor type");
612 reifiedReturnShapes.resize(1,
614 reifiedReturnShapes[0] =
622 if (isa<TensorType>(getDest().
getType())) {
635 assert(isa<BaseMemRefType>(getDest().
getType()) &&
"expected memref type");
636 assert(getRestrict() &&
637 "expected that ops with memrefs dest have 'restrict'");
639 return builder.
create<ToTensorOp>(loc, getDest(),
true,
643 bool MaterializeInDestinationOp::isEquivalentSubset(
645 return equivalenceFn(getDest(), candidate);
649 MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
653 OpOperand &MaterializeInDestinationOp::getSourceOperand() {
654 return getOperation()->getOpOperand(0) ;
657 bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
658 SubsetOpInterface subsetOp,
663 bool MaterializeInDestinationOp::operatesOnDisjointSubset(
664 SubsetOpInterface subsetOp,
670 if (!isa<TensorType, BaseMemRefType>(getDest().
getType()))
671 return emitOpError(
"'dest' must be a tensor or a memref");
672 if (
auto destType = dyn_cast<TensorType>(getDest().
getType())) {
673 if (getOperation()->getNumResults() != 1)
674 return emitOpError(
"tensor 'dest' implies exactly one tensor result");
675 if (destType != getResult().
getType())
676 return emitOpError(
"result and 'dest' types must match");
678 if (isa<BaseMemRefType>(getDest().
getType()) &&
679 getOperation()->getNumResults() != 0)
680 return emitOpError(
"memref 'dest' implies zero results");
681 if (getRestrict() && !isa<BaseMemRefType>(getDest().
getType()))
682 return emitOpError(
"'restrict' is valid only for memref destinations");
683 if (getWritable() != isa<BaseMemRefType>(getDest().
getType()))
684 return emitOpError(
"'writable' must be specified if and only if the "
685 "destination is of memref type");
687 ShapedType destType = cast<ShapedType>(getDest().
getType());
688 if (srcType.
hasRank() != destType.hasRank())
689 return emitOpError(
"source/destination shapes are incompatible");
691 if (srcType.getRank() != destType.getRank())
692 return emitOpError(
"rank mismatch between source and destination shape");
693 for (
auto [src, dest] :
694 llvm::zip(srcType.
getShape(), destType.getShape())) {
695 if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
701 return emitOpError(
"source/destination shapes are incompatible");
707 void MaterializeInDestinationOp::build(
OpBuilder &builder,
710 auto destTensorType = dyn_cast<TensorType>(dest.
getType());
711 build(builder, state, destTensorType ? destTensorType :
Type(),
715 bool MaterializeInDestinationOp::isWritable(
Value value,
717 return isa<TensorType>(getDest().
getType()) ? true : getWritable();
721 return getDestMutable();
724 void MaterializeInDestinationOp::getEffects(
727 if (isa<BaseMemRefType>(getDest().
getType()))
737 return getWritable();
741 if (
auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
744 if (toMemref->getBlock() == this->getOperation()->getBlock() &&
745 toMemref->getNextNode() == this->getOperation())
746 return toMemref.getTensor();
754 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
756 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
757 if (!memrefToTensorOp)
761 dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
769 results.
add<DimOfToTensorFolder>(context);
777 if (
auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
778 if (memrefToTensor.getMemref().getType() ==
getType())
779 return memrefToTensor.getMemref();
789 LogicalResult matchAndRewrite(ToMemrefOp toMemref,
791 auto tensorCastOperand =
792 toMemref.getOperand().getDefiningOp<tensor::CastOp>();
793 if (!tensorCastOperand)
795 auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
796 tensorCastOperand.getOperand().getType());
800 srcTensorType.getElementType());
801 Value memref = rewriter.
create<ToMemrefOp>(toMemref.getLoc(), memrefType,
802 tensorCastOperand.getOperand());
814 LogicalResult matchAndRewrite(ToMemrefOp toMemref,
827 LogicalResult matchAndRewrite(memref::LoadOp load,
829 auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
843 LogicalResult matchAndRewrite(memref::DimOp dimOp,
845 auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
848 Value newSource = castOp.getOperand();
859 results.
add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
860 ToMemrefToTensorFolding>(context);
863 LogicalResult ToMemrefOp::bufferize(
RewriterBase &rewriter,
872 std::optional<Operation *> CloneOp::buildDealloc(
OpBuilder &builder,
874 return builder.
create<memref::DeallocOp>(alloc.
getLoc(), alloc)
878 std::optional<Value> CloneOp::buildClone(
OpBuilder &builder,
Value alloc) {
879 return builder.
create<CloneOp>(alloc.
getLoc(), alloc).getResult();
886 LogicalResult DeallocOp::inferReturnTypes(
887 MLIRContext *context, std::optional<::mlir::Location> location,
890 DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
897 if (getMemrefs().size() != getConditions().size())
899 "must have the same number of conditions as memrefs to deallocate");
900 if (getRetained().size() != getUpdatedConditions().size())
901 return emitOpError(
"must have the same number of updated conditions "
902 "(results) as retained operands");
910 if (deallocOp.getMemrefs() == memrefs &&
911 deallocOp.getConditions() == conditions)
915 deallocOp.getMemrefsMutable().assign(memrefs);
916 deallocOp.getConditionsMutable().assign(conditions);
936 struct DeallocRemoveDuplicateDeallocMemrefs
940 LogicalResult matchAndRewrite(DeallocOp deallocOp,
945 for (
auto [i, memref, cond] :
947 if (memrefToCondition.count(memref)) {
950 Value &newCond = newConditions[memrefToCondition[memref]];
953 rewriter.
create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
955 memrefToCondition.insert({memref, newConditions.size()});
956 newMemrefs.push_back(memref);
957 newConditions.push_back(cond);
978 struct DeallocRemoveDuplicateRetainedMemrefs
982 LogicalResult matchAndRewrite(DeallocOp deallocOp,
989 for (
auto retained : deallocOp.getRetained()) {
990 if (seen.count(retained)) {
991 resultReplacementIdx.push_back(seen[retained]);
996 newRetained.push_back(retained);
997 resultReplacementIdx.push_back(i++);
1002 if (newRetained.size() == deallocOp.getRetained().size())
1008 rewriter.
create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
1009 deallocOp.getConditions(), newRetained);
1011 llvm::map_range(resultReplacementIdx, [&](
unsigned idx) {
1012 return newDeallocOp.getUpdatedConditions()[idx];
1014 rewriter.
replaceOp(deallocOp, replacements);
1027 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1029 if (deallocOp.getMemrefs().empty()) {
1030 Value constFalse = rewriter.
create<arith::ConstantOp>(
1056 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1059 for (
auto [memref, cond] :
1060 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1062 newMemrefs.push_back(memref);
1063 newConditions.push_back(cond);
1093 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1096 llvm::map_range(deallocOp.getMemrefs(), [&](
Value memref) {
1097 auto extractStridedOp =
1098 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1099 if (!extractStridedOp)
1101 Value allocMemref = extractStridedOp.getOperand();
1102 auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1105 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1111 deallocOp.getConditions(), rewriter);
1129 struct RemoveAllocDeallocPairWhenNoOtherUsers
1133 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1137 for (
auto [memref, cond] :
1138 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1139 if (
auto allocOp = memref.
getDefiningOp<MemoryEffectOpInterface>()) {
1144 hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
1146 toDelete.push_back(allocOp);
1151 newMemrefs.push_back(memref);
1152 newConditions.push_back(cond);
1175 patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1176 DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1177 EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1178 RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1185 #define GET_OP_CLASSES
1186 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base class for generic analysis states.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseCustomTypeWithFallback(Type &result, function_ref< ParseResult(Type &result)> parseType)=0
Parse a custom type with the provided callback, unless the next token is #, in which case the generic...
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
void printStrippedAttrOrType(AttrOrType attrOrType)
Print the provided attribute in the context of an operation custom printer/parser: this will invoke d...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
BoolAttr getBoolAttr(bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
Value buildSubsetExtraction(RewriterBase &rewriter, SubsetInsertionOpInterface op, tensor::EmptyOp emptyTensorOp, Operation *user)
This method builds and returns a subset extraction value for the destination tensor that the given op...
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref, const BufferizationOptions &options)
Try to fold to_memref(to_tensor(x)).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the giv...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
The following effect indicates that the operation allocates from some resource.
The following effect indicates that the operation frees some resource that has been allocated.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Options for BufferizableOpInterface-based bufferization.