33 struct CastOpInterface
34 :
public BufferizableOpInterface::ExternalModel<CastOpInterface,
54 auto castOp = cast<tensor::CastOp>(op);
56 castOp.getSource(),
options, invocationStack);
57 if (
failed(maybeSrcBufferType))
59 Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
65 if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
72 if (isa<UnrankedTensorType>(castOp.getType())) {
78 auto rankedResultType = cast<RankedTensorType>(castOp.getType());
80 rankedResultType.getShape(), rankedResultType.getElementType(),
81 llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
86 auto castOp = cast<tensor::CastOp>(op);
95 auto resultMemRefType =
97 if (
failed(resultMemRefType))
99 if (resultBuffer->getType() == *resultMemRefType) {
106 assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
107 *resultMemRefType) &&
108 "CallOp::bufferize: cast incompatible");
109 replaceOpWithNewBufferizedOp<memref::CastOp>(
110 rewriter, op, *resultMemRefType, *resultBuffer);
117 struct CollapseShapeOpInterface
118 :
public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
119 tensor::CollapseShapeOp> {
143 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
145 collapseShapeOp.getSrc(),
options, invocationStack);
146 if (
failed(maybeSrcBufferType))
148 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
149 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
150 srcBufferType, collapseShapeOp.getReassociationIndices());
152 if (!canBeCollapsed) {
154 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
156 tensorResultType, srcBufferType.getMemorySpace());
159 return memref::CollapseShapeOp::computeCollapsedType(
160 srcBufferType, collapseShapeOp.getReassociationIndices());
165 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
166 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
171 Value buffer = *maybeBuffer;
172 auto bufferType = cast<MemRefType>(buffer.
getType());
174 if (tensorResultType.getRank() == 0) {
176 MemRefType resultType;
178 if (bufferType.getLayout().isIdentity()) {
180 MemRefLayoutAttrInterface layout;
182 layout, bufferType.getMemorySpace());
191 {}, tensorResultType.getElementType(),
193 bufferType.getMemorySpace());
196 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
197 rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
204 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
205 bufferType, collapseShapeOp.getReassociationIndices());
206 if (!canBeCollapsed) {
215 collapseShapeOp.getSrcType().getElementType(),
216 AffineMap(), bufferType.getMemorySpace());
217 buffer = rewriter.
create<bufferization::ToMemrefOp>(
218 op->
getLoc(), memrefType, *tensorAlloc);
222 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
223 rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
229 struct DimOpInterface
230 :
public BufferizableOpInterface::ExternalModel<DimOpInterface,
250 auto dimOp = cast<tensor::DimOp>(op);
254 replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
261 struct EmptyOpInterface
262 :
public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
264 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
274 auto emptyOp = cast<tensor::EmptyOp>(op);
284 rewriter, op->
getLoc(), emptyOp.getResult(),
options,
false);
293 struct ExpandShapeOpInterface
294 :
public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
295 tensor::ExpandShapeOp> {
316 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
318 expandShapeOp.getSrc(),
options, invocationStack);
319 if (
failed(maybeSrcBufferType))
321 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
322 auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
323 srcBufferType, expandShapeOp.getResultType().getShape(),
324 expandShapeOp.getReassociationIndices());
325 if (
failed(maybeResultType))
327 return *maybeResultType;
332 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
333 auto tensorResultType = expandShapeOp.getResultType();
341 replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
342 rewriter, op, tensorResultType.getShape(), *buffer,
343 expandShapeOp.getReassociationIndices());
349 struct ExtractSliceOpInterface
350 :
public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
351 tensor::ExtractSliceOp> {
369 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
373 Location loc = extractSliceOp.getLoc();
382 auto resultMemrefType =
384 if (
failed(resultMemrefType))
386 Value subView = rewriter.
create<memref::SubViewOp>(
387 loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
388 mixedSizes, mixedStrides);
397 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
398 assert(value == extractSliceOp.getResult() &&
"invalid value");
400 extractSliceOp.getSource(),
options, invocationStack);
401 if (
failed(srcMemrefType))
406 return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
407 extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
408 mixedOffsets, mixedSizes, mixedStrides));
413 struct ExtractOpInterface
414 :
public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
433 auto extractOp = cast<tensor::ExtractOp>(op);
438 replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
439 extractOp.getIndices());
449 OperandRange::iterator &elementIt,
451 if (dim ==
static_cast<int>(shape.size()) - 1) {
452 for (
int i = 0; i < shape.back(); ++i) {
453 indices.back() = constants[i];
454 rewriter.
create<memref::StoreOp>(loc, *elementIt, buffer, indices);
459 for (
int i = 0; i < shape[dim]; ++i) {
460 indices[dim] = constants[i];
461 createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
467 struct FromElementsOpInterface
468 :
public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
469 tensor::FromElementsOp> {
471 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
475 auto fromElementsOp = cast<tensor::FromElementsOp>(op);
476 auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
480 return op->
emitError(
"memory space not implemented yet");
484 auto shape = tensorType.getShape();
487 rewriter, loc, fromElementsOp.getResult(),
options,
493 Value buffer = rewriter.
create<bufferization::ToMemrefOp>(
494 op->
getLoc(), memrefType, *tensorAlloc);
497 if (fromElementsOp.getElements().empty()) {
504 rewriter.
create<memref::StoreOp>(
505 loc, fromElementsOp.getElements().front(), buffer);
511 auto maxDim = *llvm::max_element(shape);
513 constants.reserve(maxDim);
514 for (
int i = 0; i < maxDim; ++i)
515 constants.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, i));
518 auto elementIt = fromElementsOp.getElements().begin();
520 createStores(rewriter, loc, 0, buffer, shape, constants, elementIt,
551 Value tensorDestination,
554 assert(generateBody.
hasOneBlock() &&
"expected body with single block");
555 auto tensorType = cast<RankedTensorType>(tensorDestination.
getType());
564 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
569 for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
570 indices.push_back(rewriter.
create<linalg::IndexOp>(loc, dim));
574 auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
577 return linalgOp.getResult()[0];
581 struct GenerateOpInterface
582 :
public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
583 tensor::GenerateOp> {
585 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
589 auto generateOp = cast<tensor::GenerateOp>(op);
591 auto type = generateOp.getResult().getType();
595 return op->
emitError(
"memory space not implemented yet");
600 rewriter, loc, generateOp.getResult(),
options,
605 Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
606 generateOp.getDynamicExtents(),
607 generateOp.getBody());
618 struct InsertOpInterface
623 auto insertOp = cast<tensor::InsertOp>(op);
628 rewriter.
create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
629 *destMemref, insertOp.getIndices());
640 struct InsertSliceOpInterface
642 tensor::InsertSliceOp> {
645 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
646 RankedTensorType destType = insertSliceOp.getDestType();
649 if (opOperand == insertSliceOp.getSourceMutable())
653 assert(opOperand == insertSliceOp.getDestMutable() &&
"expected dest");
657 bool allOffsetsZero =
658 llvm::all_of(insertSliceOp.getMixedOffsets(), [](
OpFoldResult ofr) {
659 return isConstantIntValue(ofr, 0);
661 bool sizesMatchDestSizes = llvm::all_of(
663 return getConstantIntValue(it.value()) ==
664 destType.getDimSize(it.index());
667 llvm::all_of(insertSliceOp.getMixedStrides(), [](
OpFoldResult ofr) {
668 return isConstantIntValue(ofr, 1);
670 return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
680 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
684 Location loc = insertSliceOp.getLoc();
693 auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
694 auto subviewMemRefType =
695 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
696 insertSliceOp.getSourceType().getShape(), dstMemrefType,
697 mixedOffsets, mixedSizes, mixedStrides));
698 Value subView = rewriter.
create<memref::SubViewOp>(
699 loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
708 if (
failed(
options.createMemCpy(rewriter, loc, *srcMemref, subView)))
720 struct PadOpInterface
721 :
public BufferizableOpInterface::ExternalModel<PadOpInterface,
723 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
744 auto padOp = cast<tensor::PadOp>(op);
746 padOp.getSource(),
options, invocationStack);
747 if (
failed(maybeSrcBufferType))
749 MemRefLayoutAttrInterface layout;
751 padOp.getResultType().getElementType(), layout,
752 maybeSrcBufferType->getMemorySpace());
757 auto padOp = cast<tensor::PadOp>(op);
759 RankedTensorType resultType = padOp.getResultType();
760 RankedTensorType srcType = padOp.getSourceType();
764 return ofr.get<
Value>();
774 for (int64_t i = 0; i < resultType.getRank(); ++i) {
775 if (!resultType.isDynamicDim(i))
777 Value srcDim = rewriter.
create<tensor::DimOp>(loc, padOp.getSource(), i);
778 Value lowPad = toValue(mixedLowPad[i]);
779 Value highPad = toValue(mixedHighPad[i]);
783 Value sum = rewriter.
create<affine::AffineApplyOp>(
784 loc, sumExpr,
ValueRange{srcDim, lowPad, highPad});
785 dynamicSizes.push_back(sum);
798 Value filledBuffer = lowerGenerateLikeOpBody(
799 rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
807 padOp, padOp.getSource(), filledBuffer,
808 padOp.getMixedLowPad(), sliceSizes, sliceStrides);
815 struct RankOpInterface
816 :
public BufferizableOpInterface::ExternalModel<RankOpInterface,
836 auto rankOp = cast<tensor::RankOp>(op);
840 replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
847 struct ReshapeOpInterface
848 :
public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
853 auto reshapeOp = cast<tensor::ReshapeOp>(op);
854 return opOperand == reshapeOp.getShapeMutable();
869 auto reshapeOp = cast<tensor::ReshapeOp>(op);
876 auto maybeResultMemRefType =
878 if (
failed(maybeResultMemRefType))
884 auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
885 if (srcType && !srcType.getLayout().isIdentity()) {
891 srcType.getShape(), srcType.getElementType(),
AffineMap(),
892 cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
894 .
create<bufferization::ToMemrefOp>(
895 op->
getLoc(), memrefType, *tensorAlloc)
899 replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
900 rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
907 auto reshapeOp = cast<tensor::ReshapeOp>(op);
908 assert(value == reshapeOp.getResult() &&
"unexpected value provided");
910 reshapeOp.getSource(),
options, invocationStack);
911 if (
failed(maybeSourceBufferType))
914 reshapeOp.getResult().getType(),
915 cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
920 struct ParallelInsertSliceOpInterface
921 :
public BufferizableOpInterface::ExternalModel<
922 ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
935 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
936 return opOperand == parallelInsertSliceOp.getDestMutable();
942 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
943 ParallelCombiningOpInterface parallelCombiningParent =
944 parallelInsertSliceOp.getParallelCombiningParent();
960 auto destBufferType = cast<MemRefType>(destBuffer->getType());
961 auto subviewMemRefType =
962 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
963 parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
964 parallelInsertSliceOp.getMixedOffsets(),
965 parallelInsertSliceOp.getMixedSizes(),
966 parallelInsertSliceOp.getMixedStrides()));
967 Value subview = rewriter.
create<memref::SubViewOp>(
968 parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
969 parallelInsertSliceOp.getMixedOffsets(),
970 parallelInsertSliceOp.getMixedSizes(),
971 parallelInsertSliceOp.getMixedStrides());
974 if (
failed(
options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
975 *srcBuffer, subview)))
985 for (
Operation *user : srcBuffer->getUsers()) {
986 if (hasEffect<MemoryEffects::Free>(user)) {
987 if (user->getBlock() == parallelCombiningParent->getBlock())
988 rewriter.
moveOpBefore(user, user->getBlock()->getTerminator());
1001 struct SplatOpInterface
1002 :
public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1005 bool bufferizesToAllocation(
Operation *op,
Value value)
const {
return true; }
1010 auto splatOp = cast<tensor::SplatOp>(op);
1015 rewriter, loc, splatOp.getResult(),
options,
1021 auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1025 return op->
emitError(
"memory space not implemented yet");
1030 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1034 rewriter.
create<linalg::YieldOp>(loc, splatOp.getInput());
1035 rewriter.
replaceOp(splatOp, linalgOp.getResult()[0]);
1048 CastOp::attachInterface<CastOpInterface>(*ctx);
1049 CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1050 DimOp::attachInterface<DimOpInterface>(*ctx);
1051 EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1052 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1053 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1054 ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1055 FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1056 GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1057 InsertOp::attachInterface<InsertOpInterface>(*ctx);
1058 InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1059 PadOp::attachInterface<PadOpInterface>(*ctx);
1060 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1062 RankOp::attachInterface<RankOpInterface>(*ctx);
1063 ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1064 SplatOp::attachInterface<SplatOpInterface>(*ctx);
1067 ctx->
loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Base class for generic analysis states.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void registerSubsetOpInterfaceExternalModels(DialectRegistry ®istry)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...