//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

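/// Bufferization of tensor.cast. Replace with memref.cast, or fold away when
/// source and result bufferize to the same memref type.
///
/// Illustrative sketch (assumed IR, memory spaces and layouts simplified):
/// ```
/// %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %m1 = memref.cast %m0 : memref<?xf32> to memref<4xf32>
/// ```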
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType =
        bufferization::detail::asMemRefType(bufferization::getBufferType(
            castOp.getSource(), options, state, invocationStack));
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor.
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return cast<BufferLikeType>(
          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
    }

    // Case 2: Casting to an unranked tensor type.
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return cast<BufferLikeType>(
          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return cast<BufferLikeType>(MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options, state);
    if (failed(resultBuffer))
      return failure();

    // Compute the new type.
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options, state);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op.
      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
      return success();
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
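///
/// Illustrative sketch (assumed IR; if the source layout is not collapsible,
/// a copy into an identity-layout buffer is inserted first):
/// ```
/// %1 = tensor.collapse_shape %0 [[0, 1]]
///     : tensor<2x3xf32> into tensor<6xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %m1 = memref.collapse_shape %m0 [[0, 1]]
///     : memref<2x3xf32> into memref<6xf32>
/// ```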
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // tensor.collapse_shape may reallocate, at which point the source buffer
    // is copied. I.e., there will be a memory read side effect on the
    // bufferized source. This function conservatively returns "true" because
    // whether a copy will be created or not is not known at this point.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // TODO: CollapseShapeOp may allocate at runtime.
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If dims cannot be collapsed, this op bufferizes to a new allocation.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return cast<BufferLikeType>(
          bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorResultType, srcBufferType.getMemorySpace()));
    }

    return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices()));
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(bufferType.getStridesAndOffset(strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(),
                                                 memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
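///
/// Illustrative sketch (assumed IR):
/// ```
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```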
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
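///
/// Illustrative sketch (assumed IR):
/// ```
/// %0 = tensor.empty() : tensor<4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = bufferization.alloc_tensor() : tensor<4xf32>
/// ```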
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
                                     const AnalysisState &state) const {
    // The returned tensor does not have specified contents.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);

    // Optimization: Fold away the op if it has no uses.
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, op->getLoc(), emptyOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(allocTensor))
      return failure();
    rewriter.replaceOp(op, *allocTensor);
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
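///
/// Illustrative sketch (assumed IR; exact output_shape syntax simplified):
/// ```
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 3]
///     : tensor<6xf32> into tensor<2x3xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %m1 = memref.expand_shape %m0 [[0, 1]] output_shape [2, 3]
///     : memref<6xf32> into memref<2x3xf32>
/// ```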
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // In contrast to tensor.collapse_shape, this op can always be bufferized
    // without a copy.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return cast<BufferLikeType>(*maybeResultType);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
    if (failed(buffer))
      return failure();

    auto memrefExpandShape = memref::ExpandShapeOp::create(
        rewriter, op->getLoc(), tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices(),
        expandShapeOp.getMixedOutputShape());
    replaceOpWithBufferizedValues(rewriter, op,
                                  memrefExpandShape->getResults());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
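///
/// Illustrative sketch (assumed IR; the subview result carries a strided
/// layout):
/// ```
/// %1 = tensor.extract_slice %0[4][8][1] : tensor<16xf32> to tensor<8xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %m1 = memref.subview %m0[4][8][1]
///     : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
/// ```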
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Unknown}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
    auto resultMemrefType = bufferization::getBufferType(
        extractSliceOp.getResult(), options, state);
    if (failed(resultMemrefType))
      return failure();
    Value subView = memref::SubViewOp::create(
        rewriter, loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, state, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides));
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
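///
/// Illustrative sketch (assumed IR):
/// ```
/// %v = tensor.extract %t[%i] : tensor<?xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %v = memref.load %m[%i] : memref<?xf32>
/// ```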
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options, state);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      memref::StoreOp::create(rewriter, loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
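///
/// Illustrative sketch (assumed IR): the result is materialized into a new
/// allocation with one memref.store per element.
/// ```
/// %0 = tensor.from_elements %a, %b : tensor<2xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %alloc = memref.alloc() : memref<2xf32>
/// memref.store %a, %alloc[%c0] : memref<2xf32>
/// memref.store %b, %alloc[%c1] : memref<2xf32>
/// ```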
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    FailureOr<BufferLikeType> memrefType =
        bufferization::getBufferType(*tensorAlloc, options, state);
    if (failed(memrefType))
      return failure();
    Value buffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(),
                                                     *memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      memref::StoreOp::create(rewriter, loc,
                              fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *llvm::max_element(shape);
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(arith::ConstantIndexOp::create(rewriter, loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Lower the body of a tensor.generate-like op (one index-typed bbArg per
/// dim). Such ops are lowered to linalg.map with the given tensor as a
/// destination.
///
/// Example:
/// ```
/// %r = tensor.generate %x, %y {
///   ^bb0(%arg0: index, %arg1: index):
///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
///   tensor.yield %0 : index
/// } : tensor<?x?xindex>
/// ```
///
/// Is lowered to:
/// ```
/// linalg.map ins() outs(%dest) {
///   %d0 = linalg.index 0 : index
///   %d1 = linalg.index 1 : index
///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
///   linalg.yield %0 : index
/// }
/// ```
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");

  // Create linalg::MapOp.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
                            /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
  linalgBody.addArgument(tensorType.getElementType(), loc);

  // Create linalg::IndexOps.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(linalg::IndexOp::create(rewriter, loc, dim));

  // Move over body.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    auto type = generateOp.getResult().getType();

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(type) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
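///
/// Illustrative sketch (assumed IR):
/// ```
/// %1 = tensor.insert %v into %t[%i] : tensor<?xf32>
/// ```
/// bufferizes roughly to a store into the destination buffer:
/// ```
/// memref.store %v, %m[%i] : memref<?xf32>
/// ```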
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options, state);
    if (failed(destMemref))
      return failure();
    memref::StoreOp::create(rewriter, insertOp.getLoc(), insertOp.getScalar(),
                            *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};

template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
                                      OpOperand &opOperand) {
  // The source is always read.
  if (opOperand == insertSliceOp.getSourceMutable())
    return true;

  // For the destination, it depends...
  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

  // Dest is not read if it is entirely overwritten. E.g.:
  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
  bool allOffsetsZero =
      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
  RankedTensorType destType = insertSliceOp.getDestType();
  bool sizesMatchDestSizes =
      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
  bool allStridesOne =
      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
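///
/// Illustrative sketch (assumed IR; the copy folds away if the slice
/// bufferizes in-place with a matching extract_slice):
/// ```
/// %1 = tensor.insert_slice %src into %dst[4][8][1]
///     : tensor<8xf32> into tensor<16xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %sv = memref.subview %m_dst[4][8][1]
///     : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
/// memref.copy %m_src, %sv
///     : memref<8xf32> to memref<8xf32, strided<[1], offset: 4>>
/// ```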
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
                                     opOperand);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options, state);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides);
    Value subView =
        memref::SubViewOp::create(rewriter, loc, subviewMemRefType, *dstMemref,
                                  mixedOffsets, mixedSizes, mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization; this is particularly
/// beneficial when padding with a constant.
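///
/// Illustrative sketch (assumed IR, static shapes for brevity):
/// ```
/// %1 = tensor.pad %0 low[1] high[1] {
/// ^bb0(%i: index):
///   tensor.yield %cst : f32
/// } : tensor<6xf32> to tensor<8xf32>
/// ```
/// bufferizes roughly to an 8-element allocation that is first filled with
/// %cst via linalg.map and into which %0 is then inserted at offset 1 via
/// tensor.insert_slice (which itself lowers to a subview plus a copy).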
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType =
        bufferization::detail::asMemRefType(bufferization::getBufferType(
            padOp.getSource(), options, state, invocationStack));
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return cast<BufferLikeType>(
        MemRefType::get(padOp.getResultType().getShape(),
                        padOp.getResultType().getElementType(), layout,
                        maybeSrcBufferType->getMemorySpace()));
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (auto value = dyn_cast<Value>(ofr))
        return value;
      return arith::ConstantIndexOp::create(rewriter, loc,
                                            *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = tensor::DimOp::create(rewriter, loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = affine::AffineApplyOp::create(
          rewriter, loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, padOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Create tensor::InsertSliceOp.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);

    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
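///
/// Illustrative sketch (assumed IR):
/// ```
/// %r = tensor.rank %t : tensor<*xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %r = memref.rank %m : memref<*xf32>
/// ```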
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v =
        getBuffer(rewriter, rankOp.getTensor(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
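///
/// Illustrative sketch (assumed IR; the source is copied first if its layout
/// is not an identity layout):
/// ```
/// %1 = tensor.reshape %0(%shape)
///     : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// ```
/// bufferizes roughly to:
/// ```
/// %m1 = memref.reshape %m0(%mshape)
///     : (memref<?x?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// ```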
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Depending on the layout map, the source buffer may have to be copied.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // Only the 'source' operand aliases the result.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    if (reshapeOp.getSourceMutable() != opOperand)
      return {};
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options, state);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options, state);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options, state);
    if (failed(maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If the source memref does not have an identity layout, copy the source
    // into a new buffer with an identity layout.
    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType = MemRefType::get(
          srcType.getShape(), srcType.getElementType(), AffineMap(),
          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
      srcBuffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(),
                                                    memrefType, *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, state, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
  }
};

/// Analysis of ParallelInsertSliceOp.
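///
/// Illustrative sketch (assumed IR): inside the terminator of an scf.forall,
/// ```
/// scf.forall.in_parallel {
///   tensor.parallel_insert_slice %src into %dst[%off][8][1]
///       : tensor<8xf32> into tensor<64xf32>
/// }
/// ```
/// bufferizes roughly to a memref.subview of the destination buffer plus a
/// memref.copy, emitted just before the in_parallel terminator.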
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    InParallelOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the in parallel terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides());
    Value subview = memref::SubViewOp::create(
        rewriter, parallelInsertSliceOp.getLoc(), subviewMemRefType,
        *destBuffer, parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yielding a value.
    //
    // Note: This is not a problem for the destination buffer because these are
    // assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }

  /// tensor.parallel_insert_slice has implicit in-place behavior; do not
  /// create a copy to resolve conflicts.
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    return success();
  }
};

/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
/// filled with a linalg.map. Similar to tensor.generate.
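///
/// Illustrative sketch (assumed IR):
/// ```
/// %0 = tensor.splat %v : tensor<4xf32>
/// ```
/// bufferizes roughly to a new allocation filled by a linalg.map whose body
/// simply yields %v.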
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    auto linalgOp = linalg::MapOp::create(rewriter, loc, tensorType,
                                          /*inputs=*/ValueRange(),
                                          /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();
    linalgBody.addArgument(tensorType.getElementType(), loc);

    // Yield the splat value from the map body.
    rewriter.setInsertionPointToStart(&linalgBody);
    linalg::YieldOp::create(rewriter, loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);

    return success();
  }
};

/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
/// filled with copy ops. Similar to tensor.from_elements, but using
/// memref.copy on subviews instead of memref.store.
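///
/// Illustrative sketch (assumed IR):
/// ```
/// %0 = tensor.concat dim(0) %a, %b
///     : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
/// ```
/// bufferizes roughly to a 5-element allocation with one memref.subview +
/// memref.copy per input, at offsets 0 and 2 along dimension 0.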
struct ConcatOpInterface
    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
                                                    tensor::ConcatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto concatOp = cast<tensor::ConcatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, concatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    MemRefLayoutAttrInterface layout;
    MemRefType memrefType =
        MemRefType::get(concatOp.getResultType().getShape(),
                        concatOp.getResultType().getElementType(), layout);
    Value dstBuffer = bufferization::ToBufferOp::create(
        rewriter, op->getLoc(), memrefType, *tensorAlloc);

    // Extract the dimension for the concat op.
    uint64_t concatDim = concatOp.getDim();
    bool dynamicConcatDim = false;

    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                      rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes;

    for (const auto &[dimIdx, dimSize] :
         llvm::enumerate(tensorType.getShape())) {
      if (dimSize == ShapedType::kDynamic) {
        auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
        sizes.push_back(dimOp.getResult());
        if (dimIdx == concatDim)
          dynamicConcatDim = true;
      } else {
        sizes.push_back(rewriter.getIndexAttr(dimSize));
      }
    }

    int64_t concatDimOffset = 0;
    std::optional<Value> dynamicOffset;
    std::optional<Value> dynamicSize;
    if (dynamicConcatDim) {
      // One or more operands have dynamic size, so we must accumulate the
      // offset with arith ops.
      dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
    }

    for (auto operand : concatOp.getInputs()) {
      // Get the buffer for the operand.
      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
      if (failed(srcBuffer))
        return failure();

      // Each operand may have a different size along the concat dimension,
      // so the offset on that axis must accumulate through the loop, and the
      // size must change to the size of the current operand.
      auto operandTensorType = cast<RankedTensorType>(operand.getType());
      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);

      if (dynamicConcatDim) {
        offsets[concatDim] = dynamicOffset.value();
        dynamicSize =
            memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
                .getResult();
        sizes[concatDim] = dynamicSize.value();
      } else {
        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
      }

      // Create a subview of the destination buffer.
      auto dstMemrefType = cast<MemRefType>(memrefType);
      MemRefType subviewMemRefType =
          memref::SubViewOp::inferRankReducedResultType(
              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
              strides);
      Value subview = memref::SubViewOp::create(
          rewriter, loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);

      // Copy the source buffer into the destination subview.
      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
        return failure();

      if (dynamicConcatDim) {
        dynamicOffset = arith::AddIOp::create(
            rewriter, loc, dynamicOffset.value(), dynamicSize.value());
      } else {
        concatDimOffset += operandConcatDimSize;
      }
    }

    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

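/// Registers the external bufferization models defined above with a dialect
/// registry. A minimal usage sketch (the surrounding pass/pipeline set-up is
/// hypothetical):
/// ```
/// DialectRegistry registry;
/// mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
/// MLIRContext ctx(registry);
/// ```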
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetOpInterfaceExternalModels(registry);
}