33 struct CastOpInterface
34 :
public BufferizableOpInterface::ExternalModel<CastOpInterface,
54 auto castOp = cast<tensor::CastOp>(op);
56 castOp.getSource(),
options, invocationStack);
57 if (
failed(maybeSrcBufferType))
59 Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
65 if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
72 if (isa<UnrankedTensorType>(castOp.getType())) {
78 auto rankedResultType = cast<RankedTensorType>(castOp.getType());
80 rankedResultType.getShape(), rankedResultType.getElementType(),
81 llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
86 auto castOp = cast<tensor::CastOp>(op);
95 auto resultMemRefType =
97 if (
failed(resultMemRefType))
99 if (resultBuffer->getType() == *resultMemRefType) {
106 assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
107 *resultMemRefType) &&
108 "CallOp::bufferize: cast incompatible");
109 replaceOpWithNewBufferizedOp<memref::CastOp>(
110 rewriter, op, *resultMemRefType, *resultBuffer);
117 struct CollapseShapeOpInterface
118 :
public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
119 tensor::CollapseShapeOp> {
143 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
145 collapseShapeOp.getSrc(),
options, invocationStack);
146 if (
failed(maybeSrcBufferType))
148 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
149 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
150 srcBufferType, collapseShapeOp.getReassociationIndices());
152 if (!canBeCollapsed) {
154 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
156 tensorResultType, srcBufferType.getMemorySpace());
159 return memref::CollapseShapeOp::computeCollapsedType(
160 srcBufferType, collapseShapeOp.getReassociationIndices());
165 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
166 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
171 Value buffer = *maybeBuffer;
172 auto bufferType = cast<MemRefType>(buffer.
getType());
174 if (tensorResultType.getRank() == 0) {
176 MemRefType resultType;
178 if (bufferType.getLayout().isIdentity()) {
180 MemRefLayoutAttrInterface layout;
182 layout, bufferType.getMemorySpace());
191 {}, tensorResultType.getElementType(),
193 bufferType.getMemorySpace());
196 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
197 rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
204 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
205 bufferType, collapseShapeOp.getReassociationIndices());
206 if (!canBeCollapsed) {
215 collapseShapeOp.getSrcType().getElementType(),
216 AffineMap(), bufferType.getMemorySpace());
217 buffer = rewriter.
create<bufferization::ToMemrefOp>(
218 op->
getLoc(), memrefType, *tensorAlloc);
222 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
223 rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
229 struct DimOpInterface
230 :
public BufferizableOpInterface::ExternalModel<DimOpInterface,
250 auto dimOp = cast<tensor::DimOp>(op);
254 replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
261 struct EmptyOpInterface
262 :
public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
264 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
274 auto emptyOp = cast<tensor::EmptyOp>(op);
284 rewriter, op->
getLoc(), emptyOp.getResult(),
options,
false);
293 struct ExpandShapeOpInterface
294 :
public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
295 tensor::ExpandShapeOp> {
316 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
318 expandShapeOp.getSrc(),
options, invocationStack);
319 if (
failed(maybeSrcBufferType))
321 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
322 auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
323 srcBufferType, expandShapeOp.getResultType().getShape(),
324 expandShapeOp.getReassociationIndices());
325 if (
failed(maybeResultType))
327 return *maybeResultType;
332 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
333 auto tensorResultType = expandShapeOp.getResultType();
341 replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
342 rewriter, op, tensorResultType.getShape(), *buffer,
343 expandShapeOp.getReassociationIndices());
349 struct ExtractSliceOpInterface
350 :
public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
351 tensor::ExtractSliceOp> {
369 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
373 Location loc = extractSliceOp.getLoc();
382 auto resultMemrefType =
384 if (
failed(resultMemrefType))
386 Value subView = rewriter.
create<memref::SubViewOp>(
387 loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
388 mixedSizes, mixedStrides);
397 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
398 assert(value == extractSliceOp.getResult() &&
"invalid value");
400 extractSliceOp.getSource(),
options, invocationStack);
401 if (
failed(srcMemrefType))
406 return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
407 extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
408 mixedOffsets, mixedSizes, mixedStrides));
413 struct ExtractOpInterface
414 :
public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
433 auto extractOp = cast<tensor::ExtractOp>(op);
438 replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
439 extractOp.getIndices());
449 OperandRange::iterator &elementIt,
451 if (dim ==
static_cast<int>(shape.size()) - 1) {
452 for (
int i = 0; i < shape.back(); ++i) {
453 indices.back() = constants[i];
454 rewriter.
create<memref::StoreOp>(loc, *elementIt, buffer, indices);
459 for (
int i = 0; i < shape[dim]; ++i) {
460 indices[dim] = constants[i];
461 createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
467 struct FromElementsOpInterface
468 :
public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
469 tensor::FromElementsOp> {
471 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
475 auto fromElementsOp = cast<tensor::FromElementsOp>(op);
479 return op->
emitError(
"memory space not implemented yet");
483 auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
484 auto shape = tensorType.getShape();
487 rewriter, loc, fromElementsOp.getResult(),
options,
493 Value buffer = rewriter.
create<bufferization::ToMemrefOp>(
494 op->
getLoc(), memrefType, *tensorAlloc);
497 if (fromElementsOp.getElements().empty()) {
504 rewriter.
create<memref::StoreOp>(
505 loc, fromElementsOp.getElements().front(), buffer);
511 auto maxDim = *std::max_element(shape.begin(), shape.end());
513 constants.reserve(maxDim);
514 for (
int i = 0; i < maxDim; ++i)
515 constants.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, i));
518 auto elementIt = fromElementsOp.getElements().begin();
520 createStores(rewriter, loc, 0, buffer, shape, constants, elementIt,
551 Value tensorDestination,
554 assert(generateBody.
hasOneBlock() &&
"expected body with single block");
555 auto tensorType = cast<RankedTensorType>(tensorDestination.
getType());
564 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
569 for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
570 indices.push_back(rewriter.
create<linalg::IndexOp>(loc, dim));
574 auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
577 return linalgOp.getResult()[0];
581 struct GenerateOpInterface
582 :
public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
583 tensor::GenerateOp> {
585 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
589 auto generateOp = cast<tensor::GenerateOp>(op);
593 return op->
emitError(
"memory space not implemented yet");
598 rewriter, loc, generateOp.getResult(),
options,
603 Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
604 generateOp.getDynamicExtents(),
605 generateOp.getBody());
616 struct InsertOpInterface
621 auto insertOp = cast<tensor::InsertOp>(op);
626 rewriter.
create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
627 *destMemref, insertOp.getIndices());
638 struct InsertSliceOpInterface
640 tensor::InsertSliceOp> {
643 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
644 RankedTensorType destType = insertSliceOp.getDestType();
647 if (opOperand == insertSliceOp.getSourceMutable())
651 assert(opOperand == insertSliceOp.getDestMutable() &&
"expected dest");
655 bool allOffsetsZero =
656 llvm::all_of(insertSliceOp.getMixedOffsets(), [](
OpFoldResult ofr) {
657 return isConstantIntValue(ofr, 0);
659 bool sizesMatchDestSizes = llvm::all_of(
661 return getConstantIntValue(it.value()) ==
662 destType.getDimSize(it.index());
665 llvm::all_of(insertSliceOp.getMixedStrides(), [](
OpFoldResult ofr) {
666 return isConstantIntValue(ofr, 1);
668 return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
678 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
682 Location loc = insertSliceOp.getLoc();
691 auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
692 auto subviewMemRefType =
693 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
694 insertSliceOp.getSourceType().getShape(), dstMemrefType,
695 mixedOffsets, mixedSizes, mixedStrides));
696 Value subView = rewriter.
create<memref::SubViewOp>(
697 loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
706 if (
failed(
options.createMemCpy(rewriter, loc, *srcMemref, subView)))
718 struct PadOpInterface
719 :
public BufferizableOpInterface::ExternalModel<PadOpInterface,
721 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
742 auto padOp = cast<tensor::PadOp>(op);
744 padOp.getSource(),
options, invocationStack);
745 if (
failed(maybeSrcBufferType))
747 MemRefLayoutAttrInterface layout;
749 padOp.getResultType().getElementType(), layout,
750 maybeSrcBufferType->getMemorySpace());
755 auto padOp = cast<tensor::PadOp>(op);
757 RankedTensorType resultType = padOp.getResultType();
758 RankedTensorType srcType = padOp.getSourceType();
762 return ofr.get<
Value>();
772 for (int64_t i = 0; i < resultType.getRank(); ++i) {
773 if (!resultType.isDynamicDim(i))
775 Value srcDim = rewriter.
create<tensor::DimOp>(loc, padOp.getSource(), i);
776 Value lowPad = toValue(mixedLowPad[i]);
777 Value highPad = toValue(mixedHighPad[i]);
781 Value sum = rewriter.
create<affine::AffineApplyOp>(
782 loc, sumExpr,
ValueRange{srcDim, lowPad, highPad});
783 dynamicSizes.push_back(sum);
796 Value filledBuffer = lowerGenerateLikeOpBody(
797 rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
805 padOp, padOp.getSource(), filledBuffer,
806 padOp.getMixedLowPad(), sliceSizes, sliceStrides);
813 struct RankOpInterface
814 :
public BufferizableOpInterface::ExternalModel<RankOpInterface,
834 auto rankOp = cast<tensor::RankOp>(op);
838 replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
845 struct ReshapeOpInterface
846 :
public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
851 auto reshapeOp = cast<tensor::ReshapeOp>(op);
852 return opOperand == reshapeOp.getShapeMutable();
867 auto reshapeOp = cast<tensor::ReshapeOp>(op);
874 auto maybeResultMemRefType =
876 if (
failed(maybeResultMemRefType))
882 auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
883 if (srcType && !srcType.getLayout().isIdentity()) {
889 srcType.getShape(), srcType.getElementType(),
AffineMap(),
890 cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
892 .
create<bufferization::ToMemrefOp>(
893 op->
getLoc(), memrefType, *tensorAlloc)
897 replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
898 rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
905 auto reshapeOp = cast<tensor::ReshapeOp>(op);
906 assert(value == reshapeOp.getResult() &&
"unexpected value provided");
908 reshapeOp.getSource(),
options, invocationStack);
909 if (
failed(maybeSourceBufferType))
912 reshapeOp.getResult().getType(),
913 cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
918 struct ParallelInsertSliceOpInterface
919 :
public BufferizableOpInterface::ExternalModel<
920 ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
933 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
934 return opOperand == parallelInsertSliceOp.getDestMutable();
940 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
941 ParallelCombiningOpInterface parallelCombiningParent =
942 parallelInsertSliceOp.getParallelCombiningParent();
958 auto destBufferType = cast<MemRefType>(destBuffer->getType());
959 auto subviewMemRefType =
960 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
961 parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
962 parallelInsertSliceOp.getMixedOffsets(),
963 parallelInsertSliceOp.getMixedSizes(),
964 parallelInsertSliceOp.getMixedStrides()));
965 Value subview = rewriter.
create<memref::SubViewOp>(
966 parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
967 parallelInsertSliceOp.getMixedOffsets(),
968 parallelInsertSliceOp.getMixedSizes(),
969 parallelInsertSliceOp.getMixedStrides());
972 if (
failed(
options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
973 *srcBuffer, subview)))
983 for (
Operation *user : srcBuffer->getUsers()) {
984 if (hasEffect<MemoryEffects::Free>(user)) {
985 if (user->getBlock() == parallelCombiningParent->getBlock())
986 user->moveBefore(user->getBlock()->getTerminator());
999 struct SplatOpInterface
1000 :
public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1003 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
1008 auto splatOp = cast<tensor::SplatOp>(op);
1012 return op->
emitError(
"memory space not implemented yet");
1017 rewriter, loc, splatOp.getResult(),
options,
1023 auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1027 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1031 rewriter.
create<linalg::YieldOp>(loc, splatOp.getInput());
1032 rewriter.
replaceOp(splatOp, linalgOp.getResult()[0]);
1045 CastOp::attachInterface<CastOpInterface>(*ctx);
1046 CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1047 DimOp::attachInterface<DimOpInterface>(*ctx);
1048 EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1049 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1050 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1051 ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1052 FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1053 GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1054 InsertOp::attachInterface<InsertOpInterface>(*ctx);
1055 InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1056 PadOp::attachInterface<PadOpInterface>(*ctx);
1057 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1059 RankOp::attachInterface<RankOpInterface>(*ctx);
1060 ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1061 SplatOp::attachInterface<SplatOpInterface>(*ctx);
1064 ctx->
loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Base class for generic analysis states.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void registerSubsetOpInterfaceExternalModels(DialectRegistry ®istry)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...