1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/Operation.h"
24 
25 using namespace mlir;
26 using namespace mlir::bufferization;
27 using namespace mlir::tensor;
28 
29 namespace mlir {
30 namespace tensor {
31 namespace {
32 
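/// Bufferization of tensor.cast. Replaced with a memref.cast (or folded away
/// entirely when the source and result buffer types already match). Example
/// (sketch; SSA names and types are illustrative only):
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
/// may bufferize to:
///   %m = memref.cast %b : memref<4xf32> to memref<?xf32>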
33 struct CastOpInterface
34  : public BufferizableOpInterface::ExternalModel<CastOpInterface,
35  tensor::CastOp> {
36  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
37  const AnalysisState &state) const {
38  return false;
39  }
40 
41  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
42  const AnalysisState &state) const {
43  return false;
44  }
45 
46  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
47  const AnalysisState &state) const {
48  return {{op->getResult(0), BufferRelation::Equivalent}};
49  }
50 
51  FailureOr<BufferLikeType>
52  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
53                const BufferizationState &state,
54                SmallVector<Value> &invocationStack) const {
55  auto castOp = cast<tensor::CastOp>(op);
56  auto maybeSrcBufferType =
57      bufferization::detail::asMemRefType(bufferization::getBufferType(
58          castOp.getSource(), options, state, invocationStack));
59  if (failed(maybeSrcBufferType))
60  return failure();
61  Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
62 
63  // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
64  // type in case the input is an unranked tensor type.
65 
66  // Case 1: Casting an unranked tensor
67  if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
68  // When casting to a ranked tensor, we cannot infer any static offset or
69  // strides from the source. Assume fully dynamic.
70  return cast<BufferLikeType>(
71  getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
72  }
73 
74  // Case 2: Casting to an unranked tensor type
75  if (isa<UnrankedTensorType>(castOp.getType())) {
76  return cast<BufferLikeType>(
77  getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
78  }
79 
80  // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
81  // change.
82  auto rankedResultType = cast<RankedTensorType>(castOp.getType());
83  return cast<BufferLikeType>(MemRefType::get(
84  rankedResultType.getShape(), rankedResultType.getElementType(),
85  llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
86  }
87 
88  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
89                          const BufferizationOptions &options,
90                          BufferizationState &state) const {
91  auto castOp = cast<tensor::CastOp>(op);
92 
93  // The result buffer still has the old (pre-cast) type.
94  FailureOr<Value> resultBuffer =
95  getBuffer(rewriter, castOp.getSource(), options, state);
96  if (failed(resultBuffer))
97  return failure();
98 
99  // Compute the new type.
100  auto resultMemRefType =
101  bufferization::getBufferType(castOp.getResult(), options, state);
102  if (failed(resultMemRefType))
103  return failure();
104  if (resultBuffer->getType() == *resultMemRefType) {
105  // This cast is a no-op.
106  replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
107  return success();
108  }
109 
110  // Replace the op with a memref.cast.
111  assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
112                                           *resultMemRefType) &&
113         "CastOp::bufferize: cast incompatible");
114  replaceOpWithNewBufferizedOp<memref::CastOp>(
115  rewriter, op, *resultMemRefType, *resultBuffer);
116 
117  return success();
118  }
119 };
120 
121 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
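/// Example (sketch; assumes a contiguous, collapsible source buffer):
///   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
/// bufferizes to:
///   %m = memref.collapse_shape %b [[0, 1]] : memref<2x4xf32> into memref<8xf32>
/// If the source layout makes the dims non-collapsible, the source is first
/// copied into a new identity-layout allocation (see bufferize below).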
122 struct CollapseShapeOpInterface
123  : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
124  tensor::CollapseShapeOp> {
125  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
126  const AnalysisState &state) const {
127  // tensor.collapse_shape may reallocate, at which point the source buffer is
128  // copied. I.e., there will be a memory read side effect on the bufferized
129  // source. This function conservatively returns "true" because whether a
130  // copy will be created or not is not known at this point.
131  return true;
132  }
133 
134  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
135  const AnalysisState &state) const {
136  return false;
137  }
138 
139  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
140  const AnalysisState &state) const {
141  // TODO: CollapseShapeOp may allocate at runtime.
142  return {{op->getOpResult(0), BufferRelation::Equivalent}};
143  }
144 
145  FailureOr<BufferLikeType>
146  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
147                const BufferizationState &state,
148                SmallVector<Value> &invocationStack) const {
149  auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
150  auto maybeSrcBufferType = bufferization::getBufferType(
151  collapseShapeOp.getSrc(), options, state, invocationStack);
152  if (failed(maybeSrcBufferType))
153  return failure();
154  auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
155  bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
156  srcBufferType, collapseShapeOp.getReassociationIndices());
157 
158  if (!canBeCollapsed) {
159  // If dims cannot be collapsed, this op bufferizes to a new allocation.
160  RankedTensorType tensorResultType = collapseShapeOp.getResultType();
161  return cast<BufferLikeType>(
162      getMemRefTypeWithStaticIdentityLayout(
163          tensorResultType, srcBufferType.getMemorySpace()));
164  }
165 
166  return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
167  srcBufferType, collapseShapeOp.getReassociationIndices()));
168  }
169 
170  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
171                          const BufferizationOptions &options,
172                          BufferizationState &state) const {
173  auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
174  RankedTensorType tensorResultType = collapseShapeOp.getResultType();
175  FailureOr<Value> maybeBuffer =
176  getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
177  if (failed(maybeBuffer))
178  return failure();
179  Value buffer = *maybeBuffer;
180  auto bufferType = cast<MemRefType>(buffer.getType());
181 
182  if (tensorResultType.getRank() == 0) {
183  // 0-d collapses must go through a different op builder.
184  MemRefType resultType;
185 
186  if (bufferType.getLayout().isIdentity()) {
187  // Standard layout: result type has no offset.
188  MemRefLayoutAttrInterface layout;
189  resultType = MemRefType::get({}, tensorResultType.getElementType(),
190  layout, bufferType.getMemorySpace());
191  } else {
192  // Source memref has a layout map: result type has the same offset as
193  // the source type.
194  SmallVector<int64_t> strides;
195  int64_t offset;
196  if (failed(bufferType.getStridesAndOffset(strides, offset)))
197  return failure();
198  resultType = MemRefType::get(
199  {}, tensorResultType.getElementType(),
200  StridedLayoutAttr::get(op->getContext(), offset, {}),
201  bufferType.getMemorySpace());
202  }
203 
204  replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
205  rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
206  return success();
207  }
208 
209  // If the dims are not collapsible (due to an incompatible source layout
210  // map), force an out-of-place bufferization, i.e., a buffer copy. This
211  // newly allocated buffer will have no layout map and thus be collapsible.
212  bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
213  bufferType, collapseShapeOp.getReassociationIndices());
214  if (!canBeCollapsed) {
215  // TODO: Create alloc_tensor ops during TensorCopyInsertion.
216  AnalysisState analysisState(options);
217  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
218  rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
219  if (failed(tensorAlloc))
220  return failure();
221  auto memrefType =
222  MemRefType::get(collapseShapeOp.getSrcType().getShape(),
223  collapseShapeOp.getSrcType().getElementType(),
224  AffineMap(), bufferType.getMemorySpace());
225  buffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(),
226  memrefType, *tensorAlloc);
227  }
228 
229  // Result type is inferred by the builder.
230  replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
231  rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
232  return success();
233  }
234 };
235 
236 /// Bufferization of tensor.dim. Replace with memref.dim.
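/// Example (sketch):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// bufferizes to:
///   %d = memref.dim %b, %c0 : memref<?xf32>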
237 struct DimOpInterface
238  : public BufferizableOpInterface::ExternalModel<DimOpInterface,
239  tensor::DimOp> {
240  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
241  const AnalysisState &state) const {
242  // The op reads the tensor's metadata but not its contents.
243  return false;
244  }
245 
246  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
247  const AnalysisState &state) const {
248  return false;
249  }
250 
251  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
252  const AnalysisState &state) const {
253  return {};
254  }
255 
256  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
257                          const BufferizationOptions &options,
258                          BufferizationState &state) const {
259  auto dimOp = cast<tensor::DimOp>(op);
260  FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
261  if (failed(v))
262  return failure();
263  replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
264  dimOp.getIndex());
265  return success();
266  }
267 };
268 
269 /// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
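/// Example (sketch):
///   %0 = tensor.empty(%sz) : tensor<?xf32>
/// is replaced with:
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>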
270 struct EmptyOpInterface
271  : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
272  tensor::EmptyOp> {
273  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
274 
275  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
276  const AnalysisState &state) const {
277  // The returned tensor does not have specified contents.
278  return false;
279  }
280 
281  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
282                          const BufferizationOptions &options,
283                          BufferizationState &state) const {
284  auto emptyOp = cast<tensor::EmptyOp>(op);
285 
286  // Optimization: Fold away the op if it has no uses.
287  if (op->getUses().empty()) {
288  rewriter.eraseOp(op);
289  return success();
290  }
291 
292  // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
293  FailureOr<Value> allocTensor = allocateTensorForShapedValue(
294  rewriter, op->getLoc(), emptyOp.getResult(), options, state,
295  /*copy=*/false);
296  if (failed(allocTensor))
297  return failure();
298  rewriter.replaceOp(op, *allocTensor);
299  return success();
300  }
301 };
302 
303 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
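/// Example (sketch; static output shape):
///   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 4]
///       : tensor<8xf32> into tensor<2x4xf32>
/// bufferizes to:
///   %m = memref.expand_shape %b [[0, 1]] output_shape [2, 4]
///       : memref<8xf32> into memref<2x4xf32>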
304 struct ExpandShapeOpInterface
305  : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
306  tensor::ExpandShapeOp> {
307  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
308  const AnalysisState &state) const {
309  // In contrast to tensor.collapse_shape, this op can always be bufferized
310  // without a copy.
311  return false;
312  }
313 
314  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
315  const AnalysisState &state) const {
316  return false;
317  }
318 
319  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
320  const AnalysisState &state) const {
321  return {{op->getOpResult(0), BufferRelation::Equivalent}};
322  }
323 
324  FailureOr<BufferLikeType>
325  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
326                const BufferizationState &state,
327                SmallVector<Value> &invocationStack) const {
328  auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
329  auto maybeSrcBufferType = bufferization::getBufferType(
330  expandShapeOp.getSrc(), options, state, invocationStack);
331  if (failed(maybeSrcBufferType))
332  return failure();
333  auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
334  auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
335  srcBufferType, expandShapeOp.getResultType().getShape(),
336  expandShapeOp.getReassociationIndices());
337  if (failed(maybeResultType))
338  return failure();
339  return cast<BufferLikeType>(*maybeResultType);
340  }
341 
342  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
343                          const BufferizationOptions &options,
344                          BufferizationState &state) const {
345  auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
346  auto tensorResultType = expandShapeOp.getResultType();
347  FailureOr<Value> buffer =
348  getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
349  if (failed(buffer))
350  return failure();
351 
352  auto memrefExpandShape = memref::ExpandShapeOp::create(
353  rewriter, op->getLoc(), tensorResultType.getShape(), *buffer,
354  expandShapeOp.getReassociationIndices(),
355  expandShapeOp.getMixedOutputShape());
356  replaceOpWithBufferizedValues(rewriter, op,
357  memrefExpandShape->getResults());
358  return success();
359  }
360 };
361 
362 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
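/// Example (sketch; %o is a dynamic offset):
///   %1 = tensor.extract_slice %0[%o] [4] [1] : tensor<?xf32> to tensor<4xf32>
/// bufferizes to a view of the source buffer:
///   %v = memref.subview %b[%o] [4] [1]
///       : memref<?xf32> to memref<4xf32, strided<[1], offset: ?>>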
363 struct ExtractSliceOpInterface
364  : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
365  tensor::ExtractSliceOp> {
366  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
367  const AnalysisState &state) const {
368  return false;
369  }
370 
371  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
372  const AnalysisState &state) const {
373  return false;
374  }
375 
376  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
377  const AnalysisState &state) const {
378  return {{op->getOpResult(0), BufferRelation::Unknown}};
379  }
380 
381  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
382                          const BufferizationOptions &options,
383                          BufferizationState &state) const {
384  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
385  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
386  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
387  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
388  Location loc = extractSliceOp.getLoc();
389 
390  // Get source buffer.
391  FailureOr<Value> srcMemref =
392  getBuffer(rewriter, extractSliceOp.getSource(), options, state);
393  if (failed(srcMemref))
394  return failure();
395 
396  // Take a subview of the source buffer.
397  auto resultMemrefType = bufferization::getBufferType(
398  extractSliceOp.getResult(), options, state);
399  if (failed(resultMemrefType))
400  return failure();
401  Value subView = memref::SubViewOp::create(
402  rewriter, loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
403  mixedOffsets, mixedSizes, mixedStrides);
404 
405  replaceOpWithBufferizedValues(rewriter, op, subView);
406  return success();
407  }
408 
409  FailureOr<BufferLikeType>
410  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
411                const BufferizationState &state,
412                SmallVector<Value> &invocationStack) const {
413  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
414  assert(value == extractSliceOp.getResult() && "invalid value");
415  auto srcMemrefType = bufferization::getBufferType(
416  extractSliceOp.getSource(), options, state, invocationStack);
417  if (failed(srcMemrefType))
418  return failure();
419  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
420  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
421  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
422  return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
423  extractSliceOp.getType().getShape(),
424  llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
425  mixedStrides));
426  }
427 };
428 
429 /// Bufferization of tensor.extract. Replace with memref.load.
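/// Example (sketch):
///   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
/// bufferizes to:
///   %e = memref.load %b[%i, %j] : memref<?x?xf32>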
430 struct ExtractOpInterface
431  : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
432  tensor::ExtractOp> {
433  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
434  const AnalysisState &state) const {
435  return true;
436  }
437 
438  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
439  const AnalysisState &state) const {
440  return false;
441  }
442 
443  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
444  const AnalysisState &state) const {
445  return {};
446  }
447 
448  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
449                          const BufferizationOptions &options,
450                          BufferizationState &state) const {
451  auto extractOp = cast<tensor::ExtractOp>(op);
452  FailureOr<Value> srcMemref =
453  getBuffer(rewriter, extractOp.getTensor(), options, state);
454  if (failed(srcMemref))
455  return failure();
456  replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
457  extractOp.getIndices());
458  return success();
459  }
460 };
461 
462 // Implements backtracking to traverse indices of the output buffer while
463 // iterating over op.elements().
464 static void createStores(RewriterBase &rewriter, Location loc, int dim,
465  Value buffer, ArrayRef<int64_t> shape,
466  ArrayRef<Value> constants,
467  OperandRange::iterator &elementIt,
468  SmallVectorImpl<Value> &indices) {
469  if (dim == static_cast<int>(shape.size()) - 1) {
470  for (int i = 0; i < shape.back(); ++i) {
471  indices.back() = constants[i];
472  memref::StoreOp::create(rewriter, loc, *elementIt, buffer, indices);
473  ++elementIt;
474  }
475  return;
476  }
477  for (int i = 0; i < shape[dim]; ++i) {
478  indices[dim] = constants[i];
479  createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
480  indices);
481  }
482 }
483 
484 /// Bufferization of tensor.from_elements.
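/// Bufferizes to a new allocation that is filled with memref.store ops (see
/// createStores above). Example (sketch; roughly what remains after the
/// allocation has been materialized):
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
///   -->
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>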
485 struct FromElementsOpInterface
486  : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
487  tensor::FromElementsOp> {
488 
489  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
490 
491  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
492                          const BufferizationOptions &options,
493                          BufferizationState &state) const {
494  auto fromElementsOp = cast<tensor::FromElementsOp>(op);
495  auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
496 
497  // Allocate a buffer for the result.
498  Location loc = op->getLoc();
499  auto shape = tensorType.getShape();
500  // TODO: Create alloc_tensor ops during TensorCopyInsertion.
501  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
502  rewriter, loc, fromElementsOp.getResult(), options, state,
503  /*copy=*/false);
504  if (failed(tensorAlloc))
505  return failure();
506  FailureOr<BufferLikeType> memrefType =
507  bufferization::getBufferType(*tensorAlloc, options, state);
508  if (failed(memrefType))
509  return failure();
510  Value buffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(),
511  *memrefType, *tensorAlloc);
512 
513  // Case: tensor<0xelem_type>.
514  if (fromElementsOp.getElements().empty()) {
515  replaceOpWithBufferizedValues(rewriter, op, buffer);
516  return success();
517  }
518 
519  // Case: tensor<elem_type>.
520  if (shape.empty()) {
521  memref::StoreOp::create(rewriter, loc,
522  fromElementsOp.getElements().front(), buffer);
523  replaceOpWithBufferizedValues(rewriter, op, buffer);
524  return success();
525  }
526 
527  // Create constants for the range of possible indices [0, max{shape_i}).
528  auto maxDim = *llvm::max_element(shape);
529  SmallVector<Value, 2> constants;
530  constants.reserve(maxDim);
531  for (int i = 0; i < maxDim; ++i)
532  constants.push_back(arith::ConstantIndexOp::create(rewriter, loc, i));
533 
534  // Traverse all `elements` and create `memref.store` ops.
535  auto elementIt = fromElementsOp.getElements().begin();
536  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
537  createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
538  indices);
539 
540  replaceOpWithBufferizedValues(rewriter, op, buffer);
541 
542  return success();
543  }
544 };
545 
546 /// Lower the body of a tensor.generate-like op (one index-typed bbArg per dim).
547 /// Such ops are lowered to linalg.map with the given tensor as a destination.
548 ///
549 /// Example:
550 /// ```
551 /// %r = tensor.generate %x, %y {
552 /// ^bb0(%arg0: index, %arg1: index):
553 /// %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
554 /// tensor.yield %0 : index
555 /// } : tensor<?x?xindex>
556 /// ```
557 ///
558 /// Is lowered to:
559 /// ```
560 /// linalg.map ins() outs(%dest) {
561 /// %d0 = linalg.index 0 : index
562 /// %d1 = linalg.index 1 : index
563 /// %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
564 /// linalg.yield %0 : index
565 /// }
566 /// ```
567 static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
568  Value tensorDestination,
569  ValueRange dynamicSizes,
570  Region &generateBody) {
571  assert(generateBody.hasOneBlock() && "expected body with single block");
572  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
573  assert(generateBody.getNumArguments() == tensorType.getRank() &&
574  "rank mismatch");
575 
576  // Create linalg::MapOp.
577  OpBuilder::InsertionGuard g(rewriter);
578  auto linalgOp =
579  linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
580  /*init=*/tensorDestination);
581  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
582 
583  // Create linalg::IndexOps.
584  rewriter.setInsertionPointToStart(&linalgBody);
585  SmallVector<Value> indices;
586  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
587  indices.push_back(linalg::IndexOp::create(rewriter, loc, dim));
588 
589  // Move over body.
590  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
591  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
592  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
593 
594  return linalgOp.getResult()[0];
595 }
596 
597 /// Bufferization of tensor.generate.
598 struct GenerateOpInterface
599  : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
600  tensor::GenerateOp> {
601 
602  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
603 
604  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
605                          const BufferizationOptions &options,
606                          BufferizationState &state) const {
607  auto generateOp = cast<tensor::GenerateOp>(op);
608 
609  auto type = generateOp.getResult().getType();
610 
611  // TODO: Implement memory space for this op.
612  if (options.defaultMemorySpaceFn(type) != Attribute())
613  return op->emitError("memory space not implemented yet");
614 
615  // Allocate memory.
616  Location loc = op->getLoc();
617  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
618  rewriter, loc, generateOp.getResult(), options, state,
619  /*copy=*/false);
620  if (failed(tensorAlloc))
621  return failure();
622 
623  Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
624  generateOp.getDynamicExtents(),
625  generateOp.getBody());
626  rewriter.replaceOp(generateOp, result);
627 
628  return success();
629  }
630 };
631 
632 /// Bufferization of tensor.insert. Replace with memref.store.
633 ///
634 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
635 /// implementations for DestinationStyle ops.
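/// Example (sketch):
///   %1 = tensor.insert %f into %t[%i] : tensor<?xf32>
/// bufferizes to a store into the destination buffer:
///   memref.store %f, %b[%i] : memref<?xf32>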
636 struct InsertOpInterface
637  : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
638  tensor::InsertOp> {
639  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
640                          const BufferizationOptions &options,
641                          BufferizationState &state) const {
642  auto insertOp = cast<tensor::InsertOp>(op);
643  FailureOr<Value> destMemref =
644  getBuffer(rewriter, insertOp.getDest(), options, state);
645  if (failed(destMemref))
646  return failure();
647  memref::StoreOp::create(rewriter, insertOp.getLoc(), insertOp.getScalar(),
648  *destMemref, insertOp.getIndices());
649  replaceOpWithBufferizedValues(rewriter, op, *destMemref);
650  return success();
651  }
652 };
653 
654 template <typename InsertOpTy>
655 static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
656  OpOperand &opOperand) {
657  // The source is always read.
658  if (opOperand == insertSliceOp.getSourceMutable())
659  return true;
660 
661  // For the destination, it depends...
662  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
663 
664  // Dest is not read if it is entirely overwritten. E.g.:
665  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
666  bool allOffsetsZero =
667  llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
668  RankedTensorType destType = insertSliceOp.getDestType();
669  bool sizesMatchDestSizes =
670  areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
671  bool allStridesOne =
672  areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
673  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
674 }
675 
676 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
677 /// certain circumstances, this op can also be a no-op.
678 ///
679 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
680 /// implementations for DestinationStyle ops.
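/// Example (sketch; loose syntax for the copy):
///   %1 = tensor.insert_slice %src into %dst[%o] [4] [1]
///       : tensor<4xf32> into tensor<?xf32>
/// bufferizes to a subview of the destination buffer plus a copy:
///   %v = memref.subview %dst_buf[%o] [4] [1]
///       : memref<?xf32> to memref<4xf32, strided<[1], offset: ?>>
///   memref.copy %src_buf, %v
/// The copy folds away when the op bufferizes in-place onto a matching
/// tensor.extract_slice.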
681 struct InsertSliceOpInterface
682  : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
683  tensor::InsertSliceOp> {
684  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
685  const AnalysisState &state) const {
686  return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
687  opOperand);
688  }
689 
690  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
691                          const BufferizationOptions &options,
692                          BufferizationState &state) const {
693  // insert_slice ops arise from tiling and bufferizing them out-of-place is
694  // generally a deal breaker. When used with loops, this ends up cloning the
695  // whole tensor on every single iteration and is a symptom of a
696  // catastrophically bad scheduling decision.
697  // TODO: be very loud about it or even consider failing the pass.
698  auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
699  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
700  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
701  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
702  Location loc = insertSliceOp.getLoc();
703 
704  // Get destination buffer.
705  FailureOr<Value> dstMemref =
706  getBuffer(rewriter, insertSliceOp.getDest(), options, state);
707  if (failed(dstMemref))
708  return failure();
709 
710  // Take a subview of the destination buffer.
711  auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
712  MemRefType subviewMemRefType =
713  memref::SubViewOp::inferRankReducedResultType(
714  insertSliceOp.getSourceType().getShape(), dstMemrefType,
715  mixedOffsets, mixedSizes, mixedStrides);
716  Value subView =
717  memref::SubViewOp::create(rewriter, loc, subviewMemRefType, *dstMemref,
718  mixedOffsets, mixedSizes, mixedStrides);
719 
720  // Copy tensor. If this tensor.insert_slice has a matching
721  // tensor.extract_slice, the copy operation will eventually fold away.
722  FailureOr<Value> srcMemref =
723  getBuffer(rewriter, insertSliceOp.getSource(), options, state);
724  if (failed(srcMemref))
725  return failure();
726  if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
727  return failure();
728 
729  replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
730  return success();
731  }
732 };
733 
734 /// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
735 /// linalg.map + insert_slice.
736 /// For best performance, vectorize before bufferization (this is particularly
737 /// beneficial when padding with a constant).
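/// Example (sketch; %cst is the padding value):
///   %1 = tensor.pad %0 low[1] high[1] {
///   ^bb0(%i: index):
///     tensor.yield %cst : f32
///   } : tensor<8xf32> to tensor<10xf32>
/// becomes, roughly: allocate a tensor of the padded size, fill it via a
/// linalg.map that runs the pad body, then tensor.insert_slice the source
/// into the filled tensor at the low-pad offsets.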
738 struct PadOpInterface
739  : public BufferizableOpInterface::ExternalModel<PadOpInterface,
740  tensor::PadOp> {
741  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
742 
743  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
744  const AnalysisState &state) const {
745  return true;
746  }
747 
748  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
749  const AnalysisState &state) const {
750  return false;
751  }
752 
753  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
754  const AnalysisState &state) const {
755  return {};
756  }
757 
758  FailureOr<BufferLikeType>
759  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
760                const BufferizationState &state,
761                SmallVector<Value> &invocationStack) const {
762  // Infer memory space from the source tensor.
763  auto padOp = cast<tensor::PadOp>(op);
764  auto maybeSrcBufferType =
765      bufferization::detail::asMemRefType(bufferization::getBufferType(
766          padOp.getSource(), options, state, invocationStack));
767  if (failed(maybeSrcBufferType))
768  return failure();
769  MemRefLayoutAttrInterface layout;
770  return cast<BufferLikeType>(
771  MemRefType::get(padOp.getResultType().getShape(),
772  padOp.getResultType().getElementType(), layout,
773  maybeSrcBufferType->getMemorySpace()));
774  }
775 
776  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
777                          const BufferizationOptions &options,
778                          BufferizationState &state) const {
779  auto padOp = cast<tensor::PadOp>(op);
780  Location loc = padOp.getLoc();
781  RankedTensorType resultType = padOp.getResultType();
782  RankedTensorType srcType = padOp.getSourceType();
783 
784  auto toValue = [&](OpFoldResult ofr) {
785  if (auto value = dyn_cast<Value>(ofr))
786  return value;
787  return arith::ConstantIndexOp::create(rewriter, loc,
788  *getConstantIntValue(ofr))
789  .getResult();
790  };
791 
792  // Compute dynamic result dimensions.
793  SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
794  SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
795  SmallVector<Value> dynamicSizes;
796  for (int64_t i = 0; i < resultType.getRank(); ++i) {
797  if (!resultType.isDynamicDim(i))
798  continue;
799  Value srcDim = tensor::DimOp::create(rewriter, loc, padOp.getSource(), i);
800  Value lowPad = toValue(mixedLowPad[i]);
801  Value highPad = toValue(mixedHighPad[i]);
802  AffineExpr s0, s1, s2;
803  bindSymbols(op->getContext(), s0, s1, s2);
804  AffineExpr sumExpr = s0 + s1 + s2;
805  Value sum = affine::AffineApplyOp::create(
806  rewriter, loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
807  dynamicSizes.push_back(sum);
808  }
809 
810  // Allocate a buffer for the padded result.
811  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
812  rewriter, loc, padOp.getResult(), options, state,
813  /*copy=*/false);
814  if (failed(tensorAlloc))
815  return failure();
816 
817  // tensor::PadOp is like tensor::GenerateOp: The only difference is that
818  // only a part of the generated tensor is needed. For simplicity, we reuse
819  // the same functionality here.
820  Value filledBuffer = lowerGenerateLikeOpBody(
821  rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
822 
823  // Create tensor::InsertSliceOp.
824  SmallVector<OpFoldResult> sliceSizes =
825  getMixedSizes(rewriter, loc, padOp.getSource());
826  SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
827  rewriter.getIndexAttr(1));
828  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
829  padOp, padOp.getSource(), filledBuffer,
830  /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
831 
832  return success();
833  }
834 };
835 
836 /// Bufferization of tensor.rank. Replace with memref.rank.
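/// Example (sketch):
///   %r = tensor.rank %t : tensor<*xf32>
/// bufferizes to:
///   %r = memref.rank %b : memref<*xf32>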
837 struct RankOpInterface
838  : public BufferizableOpInterface::ExternalModel<RankOpInterface,
839  tensor::RankOp> {
840  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
841  const AnalysisState &state) const {
842  // The op reads the tensor's metadata but not its contents.
843  return false;
844  }
845 
846  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
847  const AnalysisState &state) const {
848  return false;
849  }
850 
851  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
852  const AnalysisState &state) const {
853  return {};
854  }
855 
856  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
857                          const BufferizationOptions &options,
858                          BufferizationState &state) const {
859  auto rankOp = cast<tensor::RankOp>(op);
860  FailureOr<Value> v =
861  getBuffer(rewriter, rankOp.getTensor(), options, state);
862  if (failed(v))
863  return failure();
864  replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
865  *v);
866  return success();
867  }
868 };
869 
870 /// Bufferization of tensor.reshape. Replace with memref.reshape.
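/// Example (sketch; assumes an identity-layout source buffer):
///   %1 = tensor.reshape %0(%shape)
///       : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// bufferizes to:
///   %m = memref.reshape %b(%shape_buf)
///       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>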
871 struct ReshapeOpInterface
872  : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
873  tensor::ReshapeOp> {
874  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
875  const AnalysisState &state) const {
876  // Depending on the layout map, the source buffer may have to be copied.
877  auto reshapeOp = cast<tensor::ReshapeOp>(op);
878  return opOperand == reshapeOp.getShapeMutable();
879  }
880 
881  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
882  const AnalysisState &state) const {
883  return false;
884  }
885 
886  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
887  const AnalysisState &state) const {
888  // Only the 'source' operand aliases the result.
889  auto reshapeOp = cast<tensor::ReshapeOp>(op);
890  if (reshapeOp.getSourceMutable() != opOperand)
891  return {};
892  return {{op->getOpResult(0), BufferRelation::Equivalent}};
893  }
894 
895  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
896                          const BufferizationOptions &options,
897                          BufferizationState &state) const {
898  auto reshapeOp = cast<tensor::ReshapeOp>(op);
899  FailureOr<Value> srcBuffer =
900  getBuffer(rewriter, reshapeOp.getSource(), options, state);
901  FailureOr<Value> shapeBuffer =
902  getBuffer(rewriter, reshapeOp.getShape(), options, state);
903  if (failed(srcBuffer) || failed(shapeBuffer))
904  return failure();
905  auto maybeResultMemRefType =
906  bufferization::getBufferType(reshapeOp.getResult(), options, state);
907  if (failed(maybeResultMemRefType))
908  return failure();
909 
910  // memref.reshape requires the source buffer to have an identity layout.
911  // If the source memref does not have an identity layout, copy the source
912  // into a new buffer with an identity layout.
913  auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
914  if (srcType && !srcType.getLayout().isIdentity()) {
915  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
916  rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
917  if (failed(tensorAlloc))
918  return failure();
919  auto memrefType = MemRefType::get(
920  srcType.getShape(), srcType.getElementType(), AffineMap(),
921  cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
922  srcBuffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(),
923  memrefType, *tensorAlloc)
924  .getResult();
925  }
926 
927  replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
928  rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
929  return success();
930  }
931 
932  FailureOr<BufferLikeType>
933  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
934                const BufferizationState &state,
935                SmallVector<Value> &invocationStack) const {
936  auto reshapeOp = cast<tensor::ReshapeOp>(op);
937  assert(value == reshapeOp.getResult() && "unexpected value provided");
938  auto maybeSourceBufferType = bufferization::getBufferType(
939  reshapeOp.getSource(), options, state, invocationStack);
940  if (failed(maybeSourceBufferType))
941  return failure();
942  return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
943  reshapeOp.getResult().getType(),
944  cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
945  }
946 };
947 
948 /// Analysis of ParallelInsertSliceOp.
949 struct ParallelInsertSliceOpInterface
950  : public BufferizableOpInterface::ExternalModel<
951  ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
952  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
953  const AnalysisState &state) const {
954  return {};
955  }
956 
957  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
958  const AnalysisState &state) const {
959  return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
960  }
961 
962  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
963  const AnalysisState &state) const {
964  auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
965  return opOperand == parallelInsertSliceOp.getDestMutable();
966  }
967 
968  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
969                          const BufferizationOptions &options,
970                          BufferizationState &state) const {
971  OpBuilder::InsertionGuard g(rewriter);
972  auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
973  ParallelCombiningOpInterface parallelCombiningParent =
974  parallelInsertSliceOp.getParallelCombiningParent();
975 
976  // Bufferize the op outside of the parallel combining terminator.
977  rewriter.setInsertionPoint(parallelCombiningParent);
978 
979  // Get source and destination buffers.
980  FailureOr<Value> destBuffer =
981  getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
982  if (failed(destBuffer))
983  return failure();
984  FailureOr<Value> srcBuffer =
985  getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
986  if (failed(srcBuffer))
987  return failure();
988 
989  // Take a subview of the destination buffer.
990  auto destBufferType = cast<MemRefType>(destBuffer->getType());
991  MemRefType subviewMemRefType =
992  memref::SubViewOp::inferRankReducedResultType(
993  parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
994  parallelInsertSliceOp.getMixedOffsets(),
995  parallelInsertSliceOp.getMixedSizes(),
996  parallelInsertSliceOp.getMixedStrides());
997  Value subview = memref::SubViewOp::create(
998  rewriter, parallelInsertSliceOp.getLoc(), subviewMemRefType,
999  *destBuffer, parallelInsertSliceOp.getMixedOffsets(),
1000  parallelInsertSliceOp.getMixedSizes(),
1001  parallelInsertSliceOp.getMixedStrides());
1002 
1003  // This memcpy will fold away if everything bufferizes in-place.
1004  if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
1005  *srcBuffer, subview)))
1006  return failure();
1007 
1008  // In case the source was allocated in the same block, make sure that the
1009  // deallocation op (if any) appears after the memcpy. By default, deallocs
1010  // are placed before the terminator, but this does not work for ForallOp
1011  // because the terminator does more than just yielding a value.
1012  //
1013  // Note: This is not a problem for the destination buffer because these are
1014  // assumed to always bufferize in-place.
1015  for (Operation *user : srcBuffer->getUsers()) {
1016  if (hasEffect<MemoryEffects::Free>(user)) {
1017  if (user->getBlock() == parallelCombiningParent->getBlock())
1018  rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
1019  break;
1020  }
1021  }
1022 
1023  // Delete the op.
1024  rewriter.eraseOp(op);
1025  return success();
1026  }
1027 
1028 /// tensor.parallel_insert_slice has implicit in-place behavior; we should not
1029 /// create a copy to resolve conflicts.
1030  LogicalResult
1031  resolveConflicts(Operation *op, RewriterBase &rewriter,
1032  const AnalysisState &analysisState,
1033  const BufferizationState &bufferizationState) const {
1034  return success();
1035  }
1036 };
1037 
1038 /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
1039 /// with a linalg.map. Similar to tensor.generate.
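/// Example (sketch; loose syntax, mirroring the tensor.generate example above):
///   %t = tensor.splat %v : tensor<4xf32>
/// becomes an allocation filled by a linalg.map that yields the splat value:
///   %0 = bufferization.alloc_tensor() : tensor<4xf32>
///   %t = linalg.map ins() outs(%0) {
///     linalg.yield %v : f32
///   }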
1040 struct SplatOpInterface
1041  : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1042  tensor::SplatOp> {
1043 
1044  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1045 
1046  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1047                          const BufferizationOptions &options,
1048                          BufferizationState &state) const {
1049  OpBuilder::InsertionGuard g(rewriter);
1050  auto splatOp = cast<tensor::SplatOp>(op);
1051 
1052  // Allocate memory.
1053  Location loc = op->getLoc();
1054  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1055  rewriter, loc, splatOp.getResult(), options, state,
1056  /*copy=*/false);
1057  if (failed(tensorAlloc))
1058  return failure();
1059 
1060  // Create linalg::MapOp.
1061  auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1062 
1063  // TODO: Implement memory space for this op.
1064  if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1065  return op->emitError("memory space not implemented yet");
1066 
1067  auto linalgOp = linalg::MapOp::create(rewriter, loc, tensorType,
1068  /*inputs=*/ValueRange(),
1069  /*init=*/*tensorAlloc);
1070  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1071 
1072  // Create the linalg.map body: simply yield the splat value.
1073  rewriter.setInsertionPointToStart(&linalgBody);
1074  linalg::YieldOp::create(rewriter, loc, splatOp.getInput());
1075  rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
1076 
1077  return success();
1078  }
1079 };
1080 
1081 /// Bufferization of tensor.concat. Bufferizes to a new allocation that is
1082 /// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
1083 /// on subviews instead of memref.store.
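/// Example (sketch; loose syntax):
///   %t = tensor.concat dim(0) %a, %b : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
/// bufferizes to an allocation plus one subview + copy per input:
///   %m   = ... allocation of memref<5xf32> ...
///   %sv0 = memref.subview %m[0] [2] [1]
///   memref.copy %a_buf, %sv0
///   %sv1 = memref.subview %m[2] [3] [1]
///   memref.copy %b_buf, %sv1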
1084 struct ConcatOpInterface
1085  : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
1086  tensor::ConcatOp> {
1087 
1088  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1089 
1090  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1091  const AnalysisState &state) const {
1092  return false;
1093  }
1094 
1095  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1096  const AnalysisState &state) const {
1097  return true;
1098  }
1099 
1100  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1101  const AnalysisState &state) const {
1102  return {};
1103  }
1104 
1105  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1106                          const BufferizationOptions &options,
1107                          BufferizationState &state) const {
1108  OpBuilder::InsertionGuard g(rewriter);
1109  auto concatOp = cast<tensor::ConcatOp>(op);
1110 
1111  // Allocate memory.
1112  Location loc = op->getLoc();
1113  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1114  rewriter, loc, concatOp.getResult(), options, state,
1115  /*copy=*/false);
1116  if (failed(tensorAlloc))
1117  return failure();
1118  auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1119 
1120  // TODO: Implement memory space for this op.
1121  if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1122  return op->emitError("memory space not implemented yet");
1123 
1124  MemRefLayoutAttrInterface layout;
1125  MemRefType memrefType =
1126  MemRefType::get(concatOp.getResultType().getShape(),
1127  concatOp.getResultType().getElementType(), layout);
1128  Value dstBuffer = bufferization::ToBufferOp::create(
1129  rewriter, op->getLoc(), memrefType, *tensorAlloc);
1130 
1131  // Extract the dimension for the concat op
1132  uint64_t concatDim = concatOp.getDim();
1133  bool dynamicConcatDim = false;
1134 
1135  SmallVector<OpFoldResult> offsets(tensorType.getRank(),
1136  rewriter.getIndexAttr(0));
1137  SmallVector<OpFoldResult> strides(tensorType.getRank(),
1138  rewriter.getIndexAttr(1));
1139  SmallVector<OpFoldResult> sizes;
1140 
1141  for (const auto &[dimIdx, dimSize] :
1142  llvm::enumerate(tensorType.getShape())) {
1143  if (dimSize == ShapedType::kDynamic) {
1144  auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
1145  sizes.push_back(dimOp.getResult());
1146  if (dimIdx == concatDim)
1147  dynamicConcatDim = true;
1148  } else {
1149  sizes.push_back(rewriter.getIndexAttr(dimSize));
1150  }
1151  }
1152 
1153  int64_t concatDimOffset = 0;
1154  std::optional<Value> dynamicOffset;
1155  std::optional<Value> dynamicSize;
1156  if (dynamicConcatDim) {
1157  // One or more operands have dynamic size, so we must accumulate the
1158  // offset with arith ops.
1159  dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
1160  }
1161 
1162  for (auto operand : concatOp.getInputs()) {
1163  // Get the buffer for the operand.
1164  FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
1165  if (failed(srcBuffer))
1166  return failure();
1167 
1168  // Each operand may have a different size along the concat dimension,
1169  // so the offset on that axis must accumulate through the loop, and the
1170  // size must change to the size of the current operand.
1171  auto operandTensorType = cast<RankedTensorType>(operand.getType());
1172  int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
1173 
1174  if (dynamicConcatDim) {
1175  offsets[concatDim] = dynamicOffset.value();
1176  dynamicSize =
1177  memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
1178  .getResult();
1179  sizes[concatDim] = dynamicSize.value();
1180  } else {
1181  sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
1182  offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
1183  }
1184 
1185  // Create a subview of the destination buffer.
1186  auto dstMemrefType = cast<MemRefType>(memrefType);
1187  MemRefType subviewMemRefType =
1188  memref::SubViewOp::inferRankReducedResultType(
1189  operandTensorType.getShape(), dstMemrefType, offsets, sizes,
1190  strides);
1191  Value subview = memref::SubViewOp::create(
1192  rewriter, loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
1193 
1194  // Copy the source buffer into the destination subview.
1195  if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
1196  return failure();
1197 
1198  if (dynamicConcatDim) {
1199  dynamicOffset = arith::AddIOp::create(
1200  rewriter, loc, dynamicOffset.value(), dynamicSize.value());
1201  } else {
1202  concatDimOffset += operandConcatDimSize;
1203  }
1204  }
1205 
1206  replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
1207  return success();
1208  }
1209 };
1210 
1211 } // namespace
1212 } // namespace tensor
1213 } // namespace mlir
1214 
1215 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
1216     DialectRegistry &registry) {
1217  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
1218  CastOp::attachInterface<CastOpInterface>(*ctx);
1219  CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1220  ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
1221  DimOp::attachInterface<DimOpInterface>(*ctx);
1222  EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1223  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1224  ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1225  ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1226  FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1227  GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1228  InsertOp::attachInterface<InsertOpInterface>(*ctx);
1229  InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1230  PadOp::attachInterface<PadOpInterface>(*ctx);
1231  ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1232  *ctx);
1233  RankOp::attachInterface<RankOpInterface>(*ctx);
1234  ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1235  SplatOp::attachInterface<SplatOpInterface>(*ctx);
1236 
1237  // Load additional dialects of which ops may get created.
1238  ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
1239  });
1240 
1241  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
1242  // they are registered.
1243   registerSubsetOpInterfaceExternalModels(registry);
1244 }