34 :
public BufferizableOpInterface::ExternalModel<CastOpInterface,
36 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
37 const AnalysisState &state)
const {
41 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
42 const AnalysisState &state)
const {
46 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
47 const AnalysisState &state)
const {
48 return {{op->
getResult(0), BufferRelation::Equivalent}};
51 FailureOr<BufferLikeType>
53 const BufferizationState &state,
54 SmallVector<Value> &invocationStack)
const {
55 auto castOp = cast<tensor::CastOp>(op);
56 auto maybeSrcBufferType =
57 bufferization::detail::asMemRefType(bufferization::getBufferType(
58 castOp.getSource(),
options, state, invocationStack));
59 if (
failed(maybeSrcBufferType))
61 Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
67 if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
70 return cast<BufferLikeType>(
71 getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
75 if (isa<UnrankedTensorType>(castOp.getType())) {
76 return cast<BufferLikeType>(
77 getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
82 auto rankedResultType = cast<RankedTensorType>(castOp.getType());
83 return cast<BufferLikeType>(MemRefType::get(
84 rankedResultType.getShape(), rankedResultType.getElementType(),
85 llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
88 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
89 const BufferizationOptions &
options,
90 BufferizationState &state)
const {
91 auto castOp = cast<tensor::CastOp>(op);
94 FailureOr<Value> resultBuffer =
95 getBuffer(rewriter, castOp.getSource(),
options, state);
100 auto resultMemRefType =
101 bufferization::getBufferType(castOp.getResult(),
options, state);
102 if (
failed(resultMemRefType))
104 if (resultBuffer->getType() == *resultMemRefType) {
106 replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
111 assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
112 *resultMemRefType) &&
113 "CallOp::bufferize: cast incompatible");
114 replaceOpWithNewBufferizedOp<memref::CastOp>(
115 rewriter, op, *resultMemRefType, *resultBuffer);
122struct CollapseShapeOpInterface
123 :
public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
124 tensor::CollapseShapeOp> {
125 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
126 const AnalysisState &state)
const {
134 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
135 const AnalysisState &state)
const {
139 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
140 const AnalysisState &state)
const {
142 return {{op->
getOpResult(0), BufferRelation::Equivalent}};
145 FailureOr<BufferLikeType>
147 const BufferizationState &state,
148 SmallVector<Value> &invocationStack)
const {
149 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
150 auto maybeSrcBufferType = bufferization::getBufferType(
151 collapseShapeOp.getSrc(),
options, state, invocationStack);
152 if (
failed(maybeSrcBufferType))
154 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
155 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
156 srcBufferType, collapseShapeOp.getReassociationIndices());
158 if (!canBeCollapsed) {
160 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
161 return cast<BufferLikeType>(
162 bufferization::getMemRefTypeWithStaticIdentityLayout(
163 tensorResultType, srcBufferType.getMemorySpace()));
166 return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
167 srcBufferType, collapseShapeOp.getReassociationIndices()));
170 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
171 const BufferizationOptions &
options,
172 BufferizationState &state)
const {
173 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
174 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
175 FailureOr<Value> maybeBuffer =
176 getBuffer(rewriter, collapseShapeOp.getSrc(),
options, state);
179 Value buffer = *maybeBuffer;
180 auto bufferType = cast<MemRefType>(buffer.
getType());
182 if (tensorResultType.getRank() == 0) {
184 MemRefType resultType;
186 if (bufferType.getLayout().isIdentity()) {
188 MemRefLayoutAttrInterface layout;
189 resultType = MemRefType::get({}, tensorResultType.getElementType(),
190 layout, bufferType.getMemorySpace());
194 SmallVector<int64_t> strides;
196 if (
failed(bufferType.getStridesAndOffset(strides, offset)))
198 resultType = MemRefType::get(
199 {}, tensorResultType.getElementType(),
200 StridedLayoutAttr::get(op->
getContext(), offset, {}),
201 bufferType.getMemorySpace());
204 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
205 rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
212 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
213 bufferType, collapseShapeOp.getReassociationIndices());
214 if (!canBeCollapsed) {
216 AnalysisState analysisState(
options);
217 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
218 rewriter, op->
getLoc(), collapseShapeOp.getSrc(),
options, state);
222 MemRefType::get(collapseShapeOp.getSrcType().getShape(),
223 collapseShapeOp.getSrcType().getElementType(),
224 AffineMap(), bufferType.getMemorySpace());
225 buffer = bufferization::ToBufferOp::create(rewriter, op->
getLoc(),
226 memrefType, *tensorAlloc);
230 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
231 rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
238 :
public BufferizableOpInterface::ExternalModel<DimOpInterface,
240 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
241 const AnalysisState &state)
const {
246 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
247 const AnalysisState &state)
const {
251 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
252 const AnalysisState &state)
const {
256 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
257 const BufferizationOptions &
options,
258 BufferizationState &state)
const {
259 auto dimOp = cast<tensor::DimOp>(op);
260 FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(),
options, state);
263 replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
270struct EmptyOpInterface
271 :
public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
273 bool bufferizesToAllocation(Operation *op, Value value)
const {
return true; }
275 bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
276 const AnalysisState &state)
const {
281 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
282 const BufferizationOptions &
options,
283 BufferizationState &state)
const {
284 auto emptyOp = cast<tensor::EmptyOp>(op);
293 FailureOr<Value> allocTensor = allocateTensorForShapedValue(
304struct ExpandShapeOpInterface
305 :
public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
306 tensor::ExpandShapeOp> {
307 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
308 const AnalysisState &state)
const {
314 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
315 const AnalysisState &state)
const {
319 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
320 const AnalysisState &state)
const {
321 return {{op->
getOpResult(0), BufferRelation::Equivalent}};
324 FailureOr<BufferLikeType>
326 const BufferizationState &state,
327 SmallVector<Value> &invocationStack)
const {
328 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
329 auto maybeSrcBufferType = bufferization::getBufferType(
330 expandShapeOp.getSrc(),
options, state, invocationStack);
331 if (
failed(maybeSrcBufferType))
333 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
334 auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
335 srcBufferType, expandShapeOp.getResultType().getShape(),
336 expandShapeOp.getReassociationIndices());
337 if (
failed(maybeResultType))
339 return cast<BufferLikeType>(*maybeResultType);
342 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
343 const BufferizationOptions &
options,
344 BufferizationState &state)
const {
345 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
346 auto tensorResultType = expandShapeOp.getResultType();
347 FailureOr<Value> buffer =
348 getBuffer(rewriter, expandShapeOp.getSrc(),
options, state);
352 auto memrefExpandShape = memref::ExpandShapeOp::create(
353 rewriter, op->
getLoc(), tensorResultType.getShape(), *buffer,
354 expandShapeOp.getReassociationIndices(),
355 expandShapeOp.getMixedOutputShape());
356 replaceOpWithBufferizedValues(rewriter, op,
357 memrefExpandShape->getResults());
363struct ExtractSliceOpInterface
364 :
public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
365 tensor::ExtractSliceOp> {
366 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
367 const AnalysisState &state)
const {
371 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
372 const AnalysisState &state)
const {
376 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
377 const AnalysisState &state)
const {
378 return {{op->
getOpResult(0), BufferRelation::Unknown}};
381 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
382 const BufferizationOptions &
options,
383 BufferizationState &state)
const {
384 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
385 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
386 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
387 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
388 Location loc = extractSliceOp.getLoc();
391 FailureOr<Value> srcMemref =
392 getBuffer(rewriter, extractSliceOp.getSource(),
options, state);
397 auto resultMemrefType = bufferization::getBufferType(
398 extractSliceOp.getResult(),
options, state);
399 if (
failed(resultMemrefType))
401 Value subView = memref::SubViewOp::create(
402 rewriter, loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
403 mixedOffsets, mixedSizes, mixedStrides);
405 replaceOpWithBufferizedValues(rewriter, op, subView);
409 FailureOr<BufferLikeType>
411 const BufferizationState &state,
412 SmallVector<Value> &invocationStack)
const {
413 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
414 assert(value == extractSliceOp.getResult() &&
"invalid value");
415 auto srcMemrefType = bufferization::getBufferType(
416 extractSliceOp.getSource(),
options, state, invocationStack);
417 if (
failed(srcMemrefType))
419 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
420 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
421 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
422 return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
423 extractSliceOp.getType().getShape(),
424 llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
430struct ExtractOpInterface
431 :
public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
433 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
434 const AnalysisState &state)
const {
438 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
439 const AnalysisState &state)
const {
443 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
444 const AnalysisState &state)
const {
448 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
449 const BufferizationOptions &
options,
450 BufferizationState &state)
const {
451 auto extractOp = cast<tensor::ExtractOp>(op);
452 FailureOr<Value> srcMemref =
453 getBuffer(rewriter, extractOp.getTensor(),
options, state);
456 replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
457 extractOp.getIndices());
464static void createStores(RewriterBase &rewriter, Location loc,
int dim,
465 Value buffer, ArrayRef<int64_t> shape,
466 ArrayRef<Value> constants,
467 OperandRange::iterator &elementIt,
468 SmallVectorImpl<Value> &
indices) {
469 if (dim ==
static_cast<int>(shape.size()) - 1) {
470 for (
int i = 0; i < shape.back(); ++i) {
472 memref::StoreOp::create(rewriter, loc, *elementIt, buffer,
indices);
477 for (
int i = 0; i < shape[dim]; ++i) {
479 createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
485struct FromElementsOpInterface
486 :
public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
487 tensor::FromElementsOp> {
489 bool bufferizesToAllocation(Operation *op, Value value)
const {
return true; }
491 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
492 const BufferizationOptions &
options,
493 BufferizationState &state)
const {
494 auto fromElementsOp = cast<tensor::FromElementsOp>(op);
495 auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
498 Location loc = op->
getLoc();
499 auto shape = tensorType.getShape();
501 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
502 rewriter, loc, fromElementsOp.getResult(),
options, state,
506 FailureOr<BufferLikeType> memrefType =
507 bufferization::getBufferType(*tensorAlloc,
options, state);
510 Value buffer = bufferization::ToBufferOp::create(rewriter, op->
getLoc(),
511 *memrefType, *tensorAlloc);
514 if (fromElementsOp.getElements().empty()) {
515 replaceOpWithBufferizedValues(rewriter, op, buffer);
521 memref::StoreOp::create(rewriter, loc,
522 fromElementsOp.getElements().front(), buffer);
523 replaceOpWithBufferizedValues(rewriter, op, buffer);
528 auto maxDim = *llvm::max_element(shape);
529 SmallVector<Value, 2> constants;
530 constants.reserve(maxDim);
531 for (
int i = 0; i < maxDim; ++i)
535 auto elementIt = fromElementsOp.getElements().begin();
536 SmallVector<Value, 2>
indices(tensorType.getRank(), constants[0]);
537 createStores(rewriter, loc, 0, buffer, shape, constants, elementIt,
540 replaceOpWithBufferizedValues(rewriter, op, buffer);
567static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
568 Value tensorDestination,
570 Region &generateBody) {
571 assert(generateBody.
hasOneBlock() &&
"expected body with single block");
572 auto tensorType = cast<RankedTensorType>(tensorDestination.
getType());
577 OpBuilder::InsertionGuard g(rewriter);
579 linalg::MapOp::create(rewriter, loc, tensorType,
ValueRange(),
581 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
582 linalgBody.
addArgument(tensorType.getElementType(), loc);
587 for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
588 indices.push_back(linalg::IndexOp::create(rewriter, loc, dim));
592 auto yieldOp = cast<tensor::YieldOp>(linalgBody.
getTerminator());
595 return linalgOp.getResult()[0];
599struct GenerateOpInterface
600 :
public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
601 tensor::GenerateOp> {
603 bool bufferizesToAllocation(Operation *op, Value value)
const {
return true; }
605 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
606 const BufferizationOptions &
options,
607 BufferizationState &state)
const {
608 auto generateOp = cast<tensor::GenerateOp>(op);
610 auto type = generateOp.getResult().getType();
613 if (
options.defaultMemorySpaceFn(type) != Attribute())
614 return op->
emitError(
"memory space not implemented yet");
617 Location loc = op->
getLoc();
618 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
619 rewriter, loc, generateOp.getResult(),
options, state,
624 Value
result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
625 generateOp.getDynamicExtents(),
626 generateOp.getBody());
637struct InsertOpInterface
638 :
public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
640 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
641 const BufferizationOptions &
options,
642 BufferizationState &state)
const {
643 auto insertOp = cast<tensor::InsertOp>(op);
644 FailureOr<Value> destMemref =
645 getBuffer(rewriter, insertOp.getDest(),
options, state);
648 memref::StoreOp::create(rewriter, insertOp.getLoc(), insertOp.getScalar(),
649 *destMemref, insertOp.getIndices());
650 replaceOpWithBufferizedValues(rewriter, op, *destMemref);
655template <
typename InsertOpTy>
656static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
657 OpOperand &opOperand) {
659 if (opOperand == insertSliceOp.getSourceMutable())
663 assert(opOperand == insertSliceOp.getDestMutable() &&
"expected dest");
667 bool allOffsetsZero =
668 llvm::all_of(insertSliceOp.getMixedOffsets(),
isZeroInteger);
669 RankedTensorType destType = insertSliceOp.getDestType();
670 bool sizesMatchDestSizes =
674 return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
682struct InsertSliceOpInterface
683 :
public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
684 tensor::InsertSliceOp> {
685 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
686 const AnalysisState &state)
const {
687 return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
691 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
692 const BufferizationOptions &
options,
693 BufferizationState &state)
const {
699 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
700 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
701 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
702 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
703 Location loc = insertSliceOp.getLoc();
706 FailureOr<Value> dstMemref =
707 getBuffer(rewriter, insertSliceOp.getDest(),
options, state);
712 auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
713 MemRefType subviewMemRefType =
714 memref::SubViewOp::inferRankReducedResultType(
715 insertSliceOp.getSourceType().getShape(), dstMemrefType,
716 mixedOffsets, mixedSizes, mixedStrides);
718 memref::SubViewOp::create(rewriter, loc, subviewMemRefType, *dstMemref,
719 mixedOffsets, mixedSizes, mixedStrides);
723 FailureOr<Value> srcMemref =
724 getBuffer(rewriter, insertSliceOp.getSource(),
options, state);
727 if (
failed(
options.createMemCpy(rewriter, loc, *srcMemref, subView)))
730 replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
740 :
public BufferizableOpInterface::ExternalModel<PadOpInterface,
742 bool bufferizesToAllocation(Operation *op, Value value)
const {
return true; }
744 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
745 const AnalysisState &state)
const {
749 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
750 const AnalysisState &state)
const {
754 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
755 const AnalysisState &state)
const {
759 FailureOr<BufferLikeType>
761 const BufferizationState &state,
762 SmallVector<Value> &invocationStack)
const {
764 auto padOp = cast<tensor::PadOp>(op);
765 auto maybeSrcBufferType =
766 bufferization::detail::asMemRefType(bufferization::getBufferType(
767 padOp.getSource(),
options, state, invocationStack));
768 if (
failed(maybeSrcBufferType))
770 MemRefLayoutAttrInterface layout;
771 return cast<BufferLikeType>(
772 MemRefType::get(padOp.getResultType().getShape(),
773 padOp.getResultType().getElementType(), layout,
774 maybeSrcBufferType->getMemorySpace()));
777 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
778 const BufferizationOptions &
options,
779 BufferizationState &state)
const {
780 auto padOp = cast<tensor::PadOp>(op);
781 Location loc = padOp.getLoc();
782 RankedTensorType resultType = padOp.getResultType();
783 RankedTensorType srcType = padOp.getSourceType();
785 auto toValue = [&](OpFoldResult ofr) {
786 if (
auto value = dyn_cast<Value>(ofr))
794 SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
795 SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
796 SmallVector<Value> dynamicSizes;
797 for (int64_t i = 0; i < resultType.getRank(); ++i) {
798 if (!resultType.isDynamicDim(i))
800 Value srcDim = tensor::DimOp::create(rewriter, loc, padOp.getSource(), i);
801 Value lowPad = toValue(mixedLowPad[i]);
802 Value highPad = toValue(mixedHighPad[i]);
803 AffineExpr s0, s1, s2;
805 AffineExpr sumExpr = s0 + s1 + s2;
806 Value sum = affine::AffineApplyOp::create(
807 rewriter, loc, sumExpr,
ValueRange{srcDim, lowPad, highPad});
808 dynamicSizes.push_back(sum);
812 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
813 rewriter, loc, padOp.getResult(),
options, state,
821 Value filledBuffer = lowerGenerateLikeOpBody(
822 rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
825 SmallVector<OpFoldResult> sliceSizes =
827 SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
830 padOp, padOp.getSource(), filledBuffer,
831 padOp.getMixedLowPad(), sliceSizes, sliceStrides);
838struct RankOpInterface
839 :
public BufferizableOpInterface::ExternalModel<RankOpInterface,
841 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
842 const AnalysisState &state)
const {
847 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
848 const AnalysisState &state)
const {
852 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
853 const AnalysisState &state)
const {
857 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
858 const BufferizationOptions &
options,
859 BufferizationState &state)
const {
860 auto rankOp = cast<tensor::RankOp>(op);
862 getBuffer(rewriter, rankOp.getTensor(),
options, state);
865 replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
872struct ReshapeOpInterface
873 :
public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
875 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
876 const AnalysisState &state)
const {
878 auto reshapeOp = cast<tensor::ReshapeOp>(op);
879 return opOperand == reshapeOp.getShapeMutable();
882 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
883 const AnalysisState &state)
const {
887 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
888 const AnalysisState &state)
const {
890 auto reshapeOp = cast<tensor::ReshapeOp>(op);
891 if (reshapeOp.getSourceMutable() != opOperand)
893 return {{op->
getOpResult(0), BufferRelation::Equivalent}};
896 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
897 const BufferizationOptions &
options,
898 BufferizationState &state)
const {
899 auto reshapeOp = cast<tensor::ReshapeOp>(op);
900 FailureOr<Value> srcBuffer =
901 getBuffer(rewriter, reshapeOp.getSource(),
options, state);
902 FailureOr<Value> shapeBuffer =
903 getBuffer(rewriter, reshapeOp.getShape(),
options, state);
906 auto maybeResultMemRefType =
907 bufferization::getBufferType(reshapeOp.getResult(),
options, state);
908 if (
failed(maybeResultMemRefType))
914 auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
915 if (srcType && !srcType.getLayout().isIdentity()) {
916 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
917 rewriter, op->
getLoc(), reshapeOp.getSource(),
options, state);
920 auto memrefType = MemRefType::get(
921 srcType.getShape(), srcType.getElementType(), AffineMap(),
922 cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
923 srcBuffer = bufferization::ToBufferOp::create(rewriter, op->
getLoc(),
924 memrefType, *tensorAlloc)
928 replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
929 rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
933 FailureOr<BufferLikeType>
935 const BufferizationState &state,
936 SmallVector<Value> &invocationStack)
const {
937 auto reshapeOp = cast<tensor::ReshapeOp>(op);
938 assert(value == reshapeOp.getResult() &&
"unexpected value provided");
939 auto maybeSourceBufferType = bufferization::getBufferType(
940 reshapeOp.getSource(),
options, state, invocationStack);
941 if (
failed(maybeSourceBufferType))
943 return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
944 reshapeOp.getResult().getType(),
945 cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
950struct ParallelInsertSliceOpInterface
951 :
public BufferizableOpInterface::ExternalModel<
952 ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
953 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
954 const AnalysisState &state)
const {
958 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
959 const AnalysisState &state)
const {
960 return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
963 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
964 const AnalysisState &state)
const {
965 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
966 return opOperand == parallelInsertSliceOp.getDestMutable();
969 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
970 const BufferizationOptions &
options,
971 BufferizationState &state)
const {
972 OpBuilder::InsertionGuard g(rewriter);
973 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
974 InParallelOpInterface parallelCombiningParent =
975 parallelInsertSliceOp.getParallelCombiningParent();
981 FailureOr<Value> destBuffer =
982 getBuffer(rewriter, parallelInsertSliceOp.getDest(),
options, state);
985 FailureOr<Value> srcBuffer =
986 getBuffer(rewriter, parallelInsertSliceOp.getSource(),
options, state);
991 auto destBufferType = cast<MemRefType>(destBuffer->getType());
992 MemRefType subviewMemRefType =
993 memref::SubViewOp::inferRankReducedResultType(
994 parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
995 parallelInsertSliceOp.getMixedOffsets(),
996 parallelInsertSliceOp.getMixedSizes(),
997 parallelInsertSliceOp.getMixedStrides());
998 Value subview = memref::SubViewOp::create(
999 rewriter, parallelInsertSliceOp.getLoc(), subviewMemRefType,
1000 *destBuffer, parallelInsertSliceOp.getMixedOffsets(),
1001 parallelInsertSliceOp.getMixedSizes(),
1002 parallelInsertSliceOp.getMixedStrides());
1005 if (
failed(
options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
1006 *srcBuffer, subview)))
1016 for (Operation *user : srcBuffer->getUsers()) {
1018 if (user->getBlock() == parallelCombiningParent->getBlock())
1019 rewriter.
moveOpBefore(user, user->getBlock()->getTerminator());
1032 resolveConflicts(Operation *op, RewriterBase &rewriter,
1033 const AnalysisState &analysisState,
1034 const BufferizationState &bufferizationState)
const {
1041struct SplatOpInterface
1042 :
public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1045 bool bufferizesToAllocation(Operation *op, Value value)
const {
return true; }
1047 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1048 const BufferizationOptions &
options,
1049 BufferizationState &state)
const {
1050 OpBuilder::InsertionGuard g(rewriter);
1051 auto splatOp = cast<tensor::SplatOp>(op);
1054 Location loc = op->
getLoc();
1055 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1056 rewriter, loc, splatOp.getResult(),
options, state,
1062 auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1065 if (
options.defaultMemorySpaceFn(tensorType) != Attribute())
1066 return op->
emitError(
"memory space not implemented yet");
1068 auto linalgOp = linalg::MapOp::create(rewriter, loc, tensorType,
1071 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1072 linalgBody.
addArgument(tensorType.getElementType(), loc);
1076 linalg::YieldOp::create(rewriter, loc, splatOp.getInput());
1077 rewriter.
replaceOp(splatOp, linalgOp.getResult()[0]);
1086struct ConcatOpInterface
1087 :
public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
1090 bool bufferizesToAllocation(Operation *op, Value value)
const {
return true; }
1092 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1093 const AnalysisState &state)
const {
1097 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1098 const AnalysisState &state)
const {
1102 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1103 const AnalysisState &state)
const {
1107 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1108 const BufferizationOptions &
options,
1109 BufferizationState &state)
const {
1110 OpBuilder::InsertionGuard g(rewriter);
1111 auto concatOp = cast<tensor::ConcatOp>(op);
1114 Location loc = op->
getLoc();
1115 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1116 rewriter, loc, concatOp.getResult(),
options, state,
1120 auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1123 if (
options.defaultMemorySpaceFn(tensorType) != Attribute())
1124 return op->
emitError(
"memory space not implemented yet");
1126 MemRefLayoutAttrInterface layout;
1127 MemRefType memrefType =
1128 MemRefType::get(concatOp.getResultType().getShape(),
1129 concatOp.getResultType().getElementType(), layout);
1130 Value dstBuffer = bufferization::ToBufferOp::create(
1131 rewriter, op->
getLoc(), memrefType, *tensorAlloc);
1134 uint64_t concatDim = concatOp.getDim();
1135 bool dynamicConcatDim =
false;
1137 SmallVector<OpFoldResult> offsets(tensorType.getRank(),
1139 SmallVector<OpFoldResult> strides(tensorType.getRank(),
1141 SmallVector<OpFoldResult> sizes;
1143 for (
const auto &[dimIdx, dimSize] :
1144 llvm::enumerate(tensorType.getShape())) {
1145 if (dimSize == ShapedType::kDynamic) {
1146 auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
1147 sizes.push_back(dimOp.getResult());
1148 if (dimIdx == concatDim)
1149 dynamicConcatDim =
true;
1155 int64_t concatDimOffset = 0;
1156 std::optional<Value> dynamicOffset;
1157 std::optional<Value> dynamicSize;
1158 if (dynamicConcatDim) {
1164 for (
auto operand : concatOp.getInputs()) {
1166 FailureOr<Value> srcBuffer = getBuffer(rewriter, operand,
options, state);
1173 auto operandTensorType = cast<RankedTensorType>(operand.getType());
1174 int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
1176 if (dynamicConcatDim) {
1177 offsets[concatDim] = dynamicOffset.value();
1179 memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
1181 sizes[concatDim] = dynamicSize.value();
1183 sizes[concatDim] = rewriter.
getIndexAttr(operandConcatDimSize);
1184 offsets[concatDim] = rewriter.
getIndexAttr(concatDimOffset);
1188 auto dstMemrefType = cast<MemRefType>(memrefType);
1189 MemRefType subviewMemRefType =
1190 memref::SubViewOp::inferRankReducedResultType(
1191 operandTensorType.getShape(), dstMemrefType, offsets, sizes,
1193 Value subview = memref::SubViewOp::create(
1194 rewriter, loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
1197 if (
failed(
options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
1200 if (dynamicConcatDim) {
1201 dynamicOffset = arith::AddIOp::create(
1202 rewriter, loc, dynamicOffset.value(), dynamicSize.value());
1204 concatDimOffset += operandConcatDimSize;
1208 replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
1220 CastOp::attachInterface<CastOpInterface>(*ctx);
1221 CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1222 ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
1223 DimOp::attachInterface<DimOpInterface>(*ctx);
1224 EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1225 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1226 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1227 ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1228 FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1229 GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1230 InsertOp::attachInterface<InsertOpInterface>(*ctx);
1231 InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1232 PadOp::attachInterface<PadOpInterface>(*ctx);
1233 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1235 RankOp::attachInterface<RankOpInterface>(*ctx);
1236 ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1237 SplatOp::attachInterface<SplatOpInterface>(*ctx);
1240 ctx->
loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasEffect< MemoryEffects::Free >(Operation *)
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumArguments()
bool hasOneBlock()
Return true if this region has exactly one block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void registerSubsetOpInterfaceExternalModels(DialectRegistry ®istry)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .