33 struct CastOpInterface
34 :
public BufferizableOpInterface::ExternalModel<CastOpInterface,
51 FailureOr<BaseMemRefType>
54 auto castOp = cast<tensor::CastOp>(op);
56 castOp.getSource(),
options, invocationStack);
57 if (failed(maybeSrcBufferType))
59 Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
65 if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
72 if (isa<UnrankedTensorType>(castOp.getType())) {
78 auto rankedResultType = cast<RankedTensorType>(castOp.getType());
80 rankedResultType.getShape(), rankedResultType.getElementType(),
81 llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
86 auto castOp = cast<tensor::CastOp>(op);
89 FailureOr<Value> resultBuffer =
91 if (failed(resultBuffer))
95 auto resultMemRefType =
97 if (failed(resultMemRefType))
99 if (resultBuffer->getType() == *resultMemRefType) {
106 assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
107 *resultMemRefType) &&
108 "CallOp::bufferize: cast incompatible");
109 replaceOpWithNewBufferizedOp<memref::CastOp>(
110 rewriter, op, *resultMemRefType, *resultBuffer);
117 struct CollapseShapeOpInterface
118 :
public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
119 tensor::CollapseShapeOp> {
140 FailureOr<BaseMemRefType>
143 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
145 collapseShapeOp.getSrc(),
options, invocationStack);
146 if (failed(maybeSrcBufferType))
148 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
149 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
150 srcBufferType, collapseShapeOp.getReassociationIndices());
152 if (!canBeCollapsed) {
154 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
156 tensorResultType, srcBufferType.getMemorySpace());
159 return memref::CollapseShapeOp::computeCollapsedType(
160 srcBufferType, collapseShapeOp.getReassociationIndices());
165 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
166 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
167 FailureOr<Value> maybeBuffer =
169 if (failed(maybeBuffer))
171 Value buffer = *maybeBuffer;
172 auto bufferType = cast<MemRefType>(buffer.
getType());
174 if (tensorResultType.getRank() == 0) {
176 MemRefType resultType;
178 if (bufferType.getLayout().isIdentity()) {
180 MemRefLayoutAttrInterface layout;
182 layout, bufferType.getMemorySpace());
191 {}, tensorResultType.getElementType(),
193 bufferType.getMemorySpace());
196 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
197 rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
204 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
205 bufferType, collapseShapeOp.getReassociationIndices());
206 if (!canBeCollapsed) {
211 if (failed(tensorAlloc))
215 collapseShapeOp.getSrcType().getElementType(),
216 AffineMap(), bufferType.getMemorySpace());
217 buffer = rewriter.
create<bufferization::ToMemrefOp>(
218 op->
getLoc(), memrefType, *tensorAlloc);
222 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
223 rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
229 struct DimOpInterface
230 :
public BufferizableOpInterface::ExternalModel<DimOpInterface,
250 auto dimOp = cast<tensor::DimOp>(op);
254 replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
261 struct EmptyOpInterface
262 :
public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
264 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
274 auto emptyOp = cast<tensor::EmptyOp>(op);
284 rewriter, op->
getLoc(), emptyOp.getResult(),
options,
false);
285 if (failed(allocTensor))
293 struct ExpandShapeOpInterface
294 :
public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
295 tensor::ExpandShapeOp> {
313 FailureOr<BaseMemRefType>
316 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
318 expandShapeOp.getSrc(),
options, invocationStack);
319 if (failed(maybeSrcBufferType))
321 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
322 auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
323 srcBufferType, expandShapeOp.getResultType().getShape(),
324 expandShapeOp.getReassociationIndices());
325 if (failed(maybeResultType))
327 return *maybeResultType;
332 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
333 auto tensorResultType = expandShapeOp.getResultType();
334 FailureOr<Value> buffer =
344 replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
345 rewriter, op, tensorResultType.getShape(), *buffer,
346 expandShapeOp.getReassociationIndices());
352 struct ExtractSliceOpInterface
353 :
public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
354 tensor::ExtractSliceOp> {
372 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
376 Location loc = extractSliceOp.getLoc();
379 FailureOr<Value> srcMemref =
381 if (failed(srcMemref))
385 auto resultMemrefType =
387 if (failed(resultMemrefType))
389 Value subView = rewriter.
create<memref::SubViewOp>(
390 loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
391 mixedOffsets, mixedSizes, mixedStrides);
397 FailureOr<BaseMemRefType>
400 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
401 assert(value == extractSliceOp.getResult() &&
"invalid value");
403 extractSliceOp.getSource(),
options, invocationStack);
404 if (failed(srcMemrefType))
409 return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
410 extractSliceOp.getType().getShape(),
411 llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
417 struct ExtractOpInterface
418 :
public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
437 auto extractOp = cast<tensor::ExtractOp>(op);
438 FailureOr<Value> srcMemref =
440 if (failed(srcMemref))
442 replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
443 extractOp.getIndices());
453 OperandRange::iterator &elementIt,
455 if (dim ==
static_cast<int>(shape.size()) - 1) {
456 for (
int i = 0; i < shape.back(); ++i) {
457 indices.back() = constants[i];
458 rewriter.
create<memref::StoreOp>(loc, *elementIt, buffer, indices);
463 for (
int i = 0; i < shape[dim]; ++i) {
464 indices[dim] = constants[i];
465 createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
471 struct FromElementsOpInterface
472 :
public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
473 tensor::FromElementsOp> {
475 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
479 auto fromElementsOp = cast<tensor::FromElementsOp>(op);
480 auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
484 return op->
emitError(
"memory space not implemented yet");
488 auto shape = tensorType.getShape();
491 rewriter, loc, fromElementsOp.getResult(),
options,
493 if (failed(tensorAlloc))
497 Value buffer = rewriter.
create<bufferization::ToMemrefOp>(
498 op->
getLoc(), memrefType, *tensorAlloc);
501 if (fromElementsOp.getElements().empty()) {
508 rewriter.
create<memref::StoreOp>(
509 loc, fromElementsOp.getElements().front(), buffer);
515 auto maxDim = *llvm::max_element(shape);
517 constants.reserve(maxDim);
518 for (
int i = 0; i < maxDim; ++i)
519 constants.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, i));
522 auto elementIt = fromElementsOp.getElements().begin();
524 createStores(rewriter, loc, 0, buffer, shape, constants, elementIt,
555 Value tensorDestination,
558 assert(generateBody.
hasOneBlock() &&
"expected body with single block");
559 auto tensorType = cast<RankedTensorType>(tensorDestination.
getType());
568 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
573 for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
574 indices.push_back(rewriter.
create<linalg::IndexOp>(loc, dim));
578 auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
581 return linalgOp.getResult()[0];
585 struct GenerateOpInterface
586 :
public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
587 tensor::GenerateOp> {
589 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
593 auto generateOp = cast<tensor::GenerateOp>(op);
595 auto type = generateOp.getResult().getType();
599 return op->
emitError(
"memory space not implemented yet");
604 rewriter, loc, generateOp.getResult(),
options,
606 if (failed(tensorAlloc))
609 Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
610 generateOp.getDynamicExtents(),
611 generateOp.getBody());
622 struct InsertOpInterface
627 auto insertOp = cast<tensor::InsertOp>(op);
628 FailureOr<Value> destMemref =
630 if (failed(destMemref))
632 rewriter.
create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
633 *destMemref, insertOp.getIndices());
644 struct InsertSliceOpInterface
646 tensor::InsertSliceOp> {
649 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
650 RankedTensorType destType = insertSliceOp.getDestType();
653 if (opOperand == insertSliceOp.getSourceMutable())
657 assert(opOperand == insertSliceOp.getDestMutable() &&
"expected dest");
661 bool allOffsetsZero =
662 llvm::all_of(insertSliceOp.getMixedOffsets(), [](
OpFoldResult ofr) {
663 return isConstantIntValue(ofr, 0);
665 bool sizesMatchDestSizes = llvm::all_of(
667 return getConstantIntValue(it.value()) ==
668 destType.getDimSize(it.index());
671 llvm::all_of(insertSliceOp.getMixedStrides(), [](
OpFoldResult ofr) {
672 return isConstantIntValue(ofr, 1);
674 return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
684 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
688 Location loc = insertSliceOp.getLoc();
691 FailureOr<Value> dstMemref =
693 if (failed(dstMemref))
697 auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
698 auto subviewMemRefType =
699 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
700 insertSliceOp.getSourceType().getShape(), dstMemrefType,
701 mixedOffsets, mixedSizes, mixedStrides));
702 Value subView = rewriter.
create<memref::SubViewOp>(
703 loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
708 FailureOr<Value> srcMemref =
710 if (failed(srcMemref))
712 if (failed(
options.createMemCpy(rewriter, loc, *srcMemref, subView)))
724 struct PadOpInterface
725 :
public BufferizableOpInterface::ExternalModel<PadOpInterface,
727 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
744 FailureOr<BaseMemRefType>
748 auto padOp = cast<tensor::PadOp>(op);
750 padOp.getSource(),
options, invocationStack);
751 if (failed(maybeSrcBufferType))
753 MemRefLayoutAttrInterface layout;
755 padOp.getResultType().getElementType(), layout,
756 maybeSrcBufferType->getMemorySpace());
761 auto padOp = cast<tensor::PadOp>(op);
763 RankedTensorType resultType = padOp.getResultType();
764 RankedTensorType srcType = padOp.getSourceType();
768 return ofr.get<
Value>();
778 for (int64_t i = 0; i < resultType.getRank(); ++i) {
779 if (!resultType.isDynamicDim(i))
781 Value srcDim = rewriter.
create<tensor::DimOp>(loc, padOp.getSource(), i);
782 Value lowPad = toValue(mixedLowPad[i]);
783 Value highPad = toValue(mixedHighPad[i]);
787 Value sum = rewriter.
create<affine::AffineApplyOp>(
788 loc, sumExpr,
ValueRange{srcDim, lowPad, highPad});
789 dynamicSizes.push_back(sum);
793 FailureOr<Value> tensorAlloc =
796 if (failed(tensorAlloc))
802 Value filledBuffer = lowerGenerateLikeOpBody(
803 rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
811 padOp, padOp.getSource(), filledBuffer,
812 padOp.getMixedLowPad(), sliceSizes, sliceStrides);
819 struct RankOpInterface
820 :
public BufferizableOpInterface::ExternalModel<RankOpInterface,
840 auto rankOp = cast<tensor::RankOp>(op);
844 replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
851 struct ReshapeOpInterface
852 :
public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
857 auto reshapeOp = cast<tensor::ReshapeOp>(op);
858 return opOperand == reshapeOp.getShapeMutable();
873 auto reshapeOp = cast<tensor::ReshapeOp>(op);
874 FailureOr<Value> srcBuffer =
876 FailureOr<Value> shapeBuffer =
878 if (failed(srcBuffer) || failed(shapeBuffer))
880 auto maybeResultMemRefType =
882 if (failed(maybeResultMemRefType))
888 auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
889 if (srcType && !srcType.getLayout().isIdentity()) {
892 if (failed(tensorAlloc))
895 srcType.getShape(), srcType.getElementType(),
AffineMap(),
896 cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
898 .
create<bufferization::ToMemrefOp>(
899 op->
getLoc(), memrefType, *tensorAlloc)
903 replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
904 rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
908 FailureOr<BaseMemRefType>
911 auto reshapeOp = cast<tensor::ReshapeOp>(op);
912 assert(value == reshapeOp.getResult() &&
"unexpected value provided");
914 reshapeOp.getSource(),
options, invocationStack);
915 if (failed(maybeSourceBufferType))
918 reshapeOp.getResult().getType(),
919 cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
924 struct ParallelInsertSliceOpInterface
925 :
public BufferizableOpInterface::ExternalModel<
926 ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
939 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
940 return opOperand == parallelInsertSliceOp.getDestMutable();
946 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
947 ParallelCombiningOpInterface parallelCombiningParent =
948 parallelInsertSliceOp.getParallelCombiningParent();
954 FailureOr<Value> destBuffer =
956 if (failed(destBuffer))
958 FailureOr<Value> srcBuffer =
960 if (failed(srcBuffer))
964 auto destBufferType = cast<MemRefType>(destBuffer->getType());
965 auto subviewMemRefType =
966 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
967 parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
968 parallelInsertSliceOp.getMixedOffsets(),
969 parallelInsertSliceOp.getMixedSizes(),
970 parallelInsertSliceOp.getMixedStrides()));
971 Value subview = rewriter.
create<memref::SubViewOp>(
972 parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
973 parallelInsertSliceOp.getMixedOffsets(),
974 parallelInsertSliceOp.getMixedSizes(),
975 parallelInsertSliceOp.getMixedStrides());
978 if (failed(
options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
979 *srcBuffer, subview)))
990 if (hasEffect<MemoryEffects::Free>(user)) {
991 if (user->getBlock() == parallelCombiningParent->getBlock())
992 rewriter.
moveOpBefore(user, user->getBlock()->getTerminator());
1012 struct SplatOpInterface
1013 :
public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1016 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
1021 auto splatOp = cast<tensor::SplatOp>(op);
1026 rewriter, loc, splatOp.getResult(),
options,
1028 if (failed(tensorAlloc))
1032 auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1036 return op->
emitError(
"memory space not implemented yet");
1041 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1045 rewriter.
create<linalg::YieldOp>(loc, splatOp.getInput());
1046 rewriter.
replaceOp(splatOp, linalgOp.getResult()[0]);
1059 CastOp::attachInterface<CastOpInterface>(*ctx);
1060 CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1061 DimOp::attachInterface<DimOpInterface>(*ctx);
1062 EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1063 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1064 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1065 ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1066 FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1067 GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1068 InsertOp::attachInterface<InsertOpInterface>(*ctx);
1069 InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1070 PadOp::attachInterface<PadOpInterface>(*ctx);
1071 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1073 RankOp::attachInterface<RankOpInterface>(*ctx);
1074 ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1075 SplatOp::attachInterface<SplatOpInterface>(*ctx);
1078 ctx->
loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Base class for generic analysis states.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
user_range getUsers()
Returns a range of all users.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...