/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType =
        bufferization::detail::asMemRefType(bufferization::getBufferType(
            castOp.getSource(), options, state, invocationStack));
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Case 1: Casting an unranked tensor. Assume a fully dynamic layout.
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      return cast<BufferLikeType>(
          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
    }

    // Case 2: Casting to an unranked tensor type.
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return cast<BufferLikeType>(
          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
    }

    // Case 3: Ranked tensor -> ranked tensor. Layout is preserved.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return cast<BufferLikeType>(MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);
    // The source buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options, state);
    if (failed(resultBuffer))
      return failure();
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options, state);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op: replace with the source buffer directly.
      rewriter.replaceOp(op, *resultBuffer);
      return success();
    }
    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);
    return success();
  }
};
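// Illustrative before/after sketch of the rewrite above (hand-written IR,
// layouts elided; not taken from the upstream tests):
//
//   %r = tensor.cast %t : tensor<?xf32> to tensor<4xf32>
//
// bufferizes to a cast on the underlying buffer:
//
//   %r = memref.cast %m : memref<?xf32> to memref<4xf32>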
/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If the dims cannot be collapsed, this op bufferizes to a new
      // allocation with an identity layout.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
          tensorResultType, srcBufferType.getMemorySpace()));
    }

    return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices()));
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;
      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: the result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: keep the source offset.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(bufferType.getStridesAndOffset(strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }
      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. The
    // newly allocated buffer has no layout map and is thus collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = rewriter.create<bufferization::ToBufferOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};
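// Sketch of the collapsible case (hand-written, illustrative IR):
//
//   %r = tensor.collapse_shape %t [[0, 1]]
//       : tensor<2x4xf32> into tensor<8xf32>
//
// becomes, when the source buffer layout permits collapsing:
//
//   %r = memref.collapse_shape %m [[0, 1]]
//       : memref<2x4xf32> into memref<8xf32>
//
// Otherwise the source is first copied into an identity-layout allocation.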
/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};
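// Sketch: `%d = tensor.dim %t, %c0` simply becomes `%d = memref.dim %m, %c0`
// on the source's buffer; no data is touched.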
/// Bufferization of tensor.empty. The op carries only a shape, so it is
/// replaced with an equivalent bufferization.alloc_tensor.
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);

    // Optimization: Fold away the op if it has no uses.
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Allocate a tensor (in the form of a bufferization.alloc_tensor op).
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, op->getLoc(), emptyOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(allocTensor))
      return failure();
    rewriter.replaceOp(op, *allocTensor);
    return success();
  }
};
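// Note (illustrative pipeline): tensor.empty has no defined contents, so it
// goes tensor.empty -> bufferization.alloc_tensor -> memref.alloc; the exact
// lowering steps depend on the surrounding pass setup.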
/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return cast<BufferLikeType>(*maybeResultType);
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
    if (failed(buffer))
      return failure();

    auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
        op->getLoc(), tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices(),
        expandShapeOp.getMixedOutputShape());
    replaceOpWithBufferizedValues(rewriter, op,
                                  memrefExpandShape->getResults());
    return success();
  }
};
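// Sketch (hand-written, illustrative IR):
//
//   %r = tensor.expand_shape %t [[0, 1]] output_shape [2, 4]
//       : tensor<8xf32> into tensor<2x4xf32>
//
// bufferizes to:
//
//   %r = memref.expand_shape %m [[0, 1]] output_shape [2, 4]
//       : memref<8xf32> into memref<2x4xf32>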
/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get the source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
    auto resultMemrefType = bufferization::getBufferType(
        extractSliceOp.getResult(), options, state);
    if (failed(resultMemrefType))
      return failure();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, state, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides));
  }
};
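// Sketch: extract_slice becomes a metadata-only subview of the source buffer
// (hand-written, illustrative IR):
//
//   %r = tensor.extract_slice %t[4] [8] [1] : tensor<16xf32> to tensor<8xf32>
//   ==>
//   %r = memref.subview %m[4] [8] [1]
//       : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>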
/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options, state);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};
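// Sketch: `%e = tensor.extract %t[%i]` becomes `%e = memref.load %m[%i]`.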
/// Implementation detail of tensor.from_elements: recursively walk the
/// dimensions and emit one memref.store per element.
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}
/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto shape = tensorType.getShape();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    FailureOr<BufferLikeType> memrefType =
        bufferization::getBufferType(*tensorAlloc, options, state);
    if (failed(memrefType))
      return failure();
    Value buffer = rewriter.create<bufferization::ToBufferOp>(
        op->getLoc(), *memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: 0-d tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *llvm::max_element(shape);
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all elements and create one memref.store per element.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};
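// Sketch (hand-written, illustrative IR): `tensor.from_elements %a, %b :
// tensor<2xf32>` allocates a buffer and stores each element individually:
//
//   %buf = memref.alloc() : memref<2xf32>
//   memref.store %a, %buf[%c0] : memref<2xf32>
//   memref.store %b, %buf[%c1] : memref<2xf32>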
/// Lower the body of a tensor.generate-like op to a linalg.map that writes
/// into `tensorDestination`.
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());

  // Create a linalg.map whose mapper region will hold the generate body.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                     /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();

  // Create one linalg.index per dimension; they replace the block arguments.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));

  // Move the body over and rewrite the tensor.yield terminator.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}
/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    auto type = generateOp.getResult().getType();

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(type) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);
    return success();
  }
};
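// Sketch: a tensor.generate body such as
//   ^bb0(%i : index): ... tensor.yield %v
// is moved into the mapper region of a linalg.map over the fresh allocation;
// the index block arguments are replaced by linalg.index ops and the
// tensor.yield becomes a linalg.yield.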
/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options, state);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};
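// Sketch: `%r = tensor.insert %v into %t[%i]` becomes an in-place
// `memref.store %v, %m[%i]`, with %m also serving as the result buffer.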
/// Return true if `opOperand` of the given insert_slice-like op must be read.
/// The source is always read; the destination is read only if it is not
/// entirely overwritten.
template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
                                      OpOperand &opOperand) {
  // The source is always read.
  if (opOperand == insertSliceOp.getSourceMutable())
    return true;

  // For the destination, it depends on the offsets/sizes/strides.
  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

  // Dest is not read if it is entirely overwritten, e.g.:
  //   tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
  bool allOffsetsZero =
      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
  RankedTensorType destType = insertSliceOp.getDestType();
  bool sizesMatchDestSizes =
      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
  bool allStridesOne =
      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}
/// Bufferization of tensor.insert_slice. Replace with a memory copy into a
/// subview of the destination buffer.
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
                                     opOperand);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get the destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options, state);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides);
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy the source into the subview. If this insert_slice has a matching
    // extract_slice, the copy eventually folds away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};
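// Sketch (hand-written, illustrative IR): insert_slice copies the source into
// a subview of the destination buffer:
//
//   %r = tensor.insert_slice %src into %dst[4] [8] [1]
//       : tensor<8xf32> into tensor<16xf32>
//   ==>
//   %sv = memref.subview %dst_m[4] [8] [1]
//       : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
//   memref.copy %src_m, %sv
//
// (the concrete copy op depends on BufferizationOptions::createMemCpy).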
/// Bufferization of tensor.pad. Replace with allocation + filling + an
/// insert_slice of the source.
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    // Infer the memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType =
        bufferization::detail::asMemRefType(bufferization::getBufferType(
            padOp.getSource(), options, state, invocationStack));
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return cast<BufferLikeType>(
        MemRefType::get(padOp.getResultType().getShape(),
                        padOp.getResultType().getElementType(), layout,
                        maybeSrcBufferType->getMemorySpace()));
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (auto value = dyn_cast<Value>(ofr))
        return value;
      return rewriter
          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions: srcDim + lowPad + highPad.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer and fill it with the padding value.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, padOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Insert the source tensor into the filled buffer at the low-pad offsets.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
    return success();
  }
};
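// Sketch: tensor.pad is decomposed rather than mapped to a single memref op.
// It allocates the padded result, fills it from the pad body via linalg.map,
// then emits a tensor.insert_slice of the source at the low-pad offsets; that
// insert_slice then bufferizes as above.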
/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v =
        getBuffer(rewriter, rankOp.getTensor(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};
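// Sketch: `%r = tensor.rank %t` becomes `%r = memref.rank %m`.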
/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The shape operand is always read.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    if (reshapeOp.getSourceMutable() != opOperand)
      return {};
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options, state);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options, state);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options, state);
    if (failed(maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If it does not, copy the source into a new identity-layout buffer.
    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType = MemRefType::get(
          srcType.getShape(), srcType.getElementType(), AffineMap(),
          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
      srcBuffer = rewriter
                      .create<bufferization::ToBufferOp>(
                          op->getLoc(), memrefType, *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, state, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
  }
};
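// Sketch (hand-written, illustrative IR):
//
//   %r = tensor.reshape %t(%shape)
//       : (tensor<16xf32>, tensor<2xindex>) -> tensor<4x4xf32>
//   ==>
//   %r = memref.reshape %m(%shape_m)
//       : (memref<16xf32>, memref<2xindex>) -> memref<4x4xf32>
//
// with an extra copy of %m beforehand if its layout is not the identity.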
/// Bufferization of tensor.parallel_insert_slice.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get the source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides());
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy is safe: the terminator guarantees exclusive access to the
    // written-to subview.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // If the source buffer is deallocated in the same block, make sure the
    // dealloc is placed after the memcpy rather than before the terminator.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
      }
    }

    rewriter.eraseOp(op);
    return success();
  }
};
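// Note: unlike tensor.insert_slice, this op has no results. The subview and
// copy are emitted outside the parallel combining terminator (e.g. the one of
// scf.forall), and the op itself is simply erased.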
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
/// filled with a linalg.map. Similar to tensor.generate.
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create a linalg.map that yields the splat value for every element.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    auto linalgOp = rewriter.create<linalg::MapOp>(loc, tensorType,
                                                   /*inputs=*/ValueRange(),
                                                   /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
    return success();
  }
};
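// Sketch: `tensor.splat %v : tensor<16xf32>` becomes a linalg.map over a new
// allocation whose body simply yields %v for every element.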
/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
/// populated with one copy per input operand.
struct ConcatOpInterface
    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
                                                    tensor::ConcatOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto concatOp = cast<tensor::ConcatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, concatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    MemRefLayoutAttrInterface layout;
    MemRefType memrefType =
        MemRefType::get(concatOp.getResultType().getShape(),
                        concatOp.getResultType().getElementType(), layout);
    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Extract the concatenation dimension.
    uint64_t concatDim = concatOp.getDim();
    bool dynamicConcatDim = false;

    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                      rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes;

    for (const auto &[dimIdx, dimSize] :
         llvm::enumerate(tensorType.getShape())) {
      if (dimSize == ShapedType::kDynamic) {
        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
        sizes.push_back(dimOp.getResult());
        if (dimIdx == concatDim)
          dynamicConcatDim = true;
      } else {
        sizes.push_back(rewriter.getIndexAttr(dimSize));
      }
    }

    int64_t concatDimOffset = 0;
    std::optional<Value> dynamicOffset;
    std::optional<Value> dynamicSize;
    if (dynamicConcatDim) {
      // The offset along the concat dimension is dynamic: track it as a Value.
      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    }

    for (auto operand : concatOp.getInputs()) {
      // Get the buffer for the operand.
      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
      if (failed(srcBuffer))
        return failure();

      // Each operand is copied into a subview of the destination buffer
      // along the concat dimension.
      auto operandTensorType = cast<RankedTensorType>(operand.getType());
      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);

      if (dynamicConcatDim) {
        offsets[concatDim] = dynamicOffset.value();
        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
                          .getResult();
        sizes[concatDim] = dynamicSize.value();
      } else {
        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
      }

      auto dstMemrefType = cast<MemRefType>(memrefType);
      MemRefType subviewMemRefType =
          memref::SubViewOp::inferRankReducedResultType(
              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
              strides);
      Value subview = rewriter.create<memref::SubViewOp>(
          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);

      // Copy the operand buffer into the destination subview.
      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
        return failure();

      // Advance the offset along the concat dimension.
      if (dynamicConcatDim) {
        dynamicOffset = rewriter.create<arith::AddIOp>(
            loc, dynamicOffset.value(), dynamicSize.value());
      } else {
        concatDimOffset += operandConcatDimSize;
      }
    }

    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
    return success();
  }
};
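// Sketch (pseudo-IR, static case): concatenating two tensors along dim 0
// becomes one subview + copy per input:
//
//   %r = tensor.concat dim(0) %a, %b
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
//   ==>
//   %buf = memref.alloc() : memref<8xf32>
//   memref.copy %a_m, <subview %buf[0] [4] [1]>
//   memref.copy %b_m, <subview %buf[4] [4] [1]>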
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });
}
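// Typical usage (illustrative): add the models to a DialectRegistry before
// creating the context, e.g.:
//
//   DialectRegistry registry;
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);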