1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/Dialect.h"
24 #include "mlir/IR/Operation.h"
25 
26 using namespace mlir;
27 using namespace mlir::bufferization;
28 using namespace mlir::tensor;
29 
30 namespace mlir {
31 namespace tensor {
32 namespace {
33 
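/// Bufferization of tensor.cast. Folds away when source and result bufferize
/// to the same buffer type; otherwise replaced with memref.cast.
///
/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
/// ```
/// bufferizes to
/// ```
/// %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```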
34 struct CastOpInterface
35  : public BufferizableOpInterface::ExternalModel<CastOpInterface,
36  tensor::CastOp> {
37  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
38  const AnalysisState &state) const {
39  return false;
40  }
41 
42  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
43  const AnalysisState &state) const {
44  return false;
45  }
46 
47  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
48  const AnalysisState &state) const {
49  return {{op->getResult(0), BufferRelation::Equivalent}};
50  }
51 
52  FailureOr<BufferLikeType>
53  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
54  const BufferizationState &state,
55  SmallVector<Value> &invocationStack) const {
56  auto castOp = cast<tensor::CastOp>(op);
57  auto maybeSrcBufferType =
58  asMemRefType(bufferization::getBufferType(
59  castOp.getSource(), options, state, invocationStack));
60  if (failed(maybeSrcBufferType))
61  return failure();
62  Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
63 
64  // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
65  // type in case the input is an unranked tensor type.
66 
67  // Case 1: Casting an unranked tensor
68  if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
69  // When casting to a ranked tensor, we cannot infer any static offset or
70  // strides from the source. Assume fully dynamic.
71  return cast<BufferLikeType>(
72  getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
73  }
74 
75  // Case 2: Casting to an unranked tensor type
76  if (isa<UnrankedTensorType>(castOp.getType())) {
77  return cast<BufferLikeType>(
78  getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
79  }
80 
81  // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
82  // change.
83  auto rankedResultType = cast<RankedTensorType>(castOp.getType());
84  return cast<BufferLikeType>(MemRefType::get(
85  rankedResultType.getShape(), rankedResultType.getElementType(),
86  llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
87  }
88 
89  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
90  const BufferizationOptions &options,
91  BufferizationState &state) const {
92  auto castOp = cast<tensor::CastOp>(op);
93 
94  // The result buffer still has the old (pre-cast) type.
95  FailureOr<Value> resultBuffer =
96  getBuffer(rewriter, castOp.getSource(), options, state);
97  if (failed(resultBuffer))
98  return failure();
99 
100  // Compute the new type.
101  auto resultMemRefType =
102  bufferization::getBufferType(castOp.getResult(), options, state);
103  if (failed(resultMemRefType))
104  return failure();
105  if (resultBuffer->getType() == *resultMemRefType) {
106  // This cast is a no-op.
107  replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
108  return success();
109  }
110 
111  // Replace the op with a memref.cast.
112  assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
113  *resultMemRefType) &&
114  "CastOp::bufferize: cast incompatible");
115  replaceOpWithNewBufferizedOp<memref::CastOp>(
116  rewriter, op, *resultMemRefType, *resultBuffer);
117 
118  return success();
119  }
120 };
121 
122 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
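/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
/// ```
/// bufferizes, when the source buffer is contiguous, to
/// ```
/// %1 = memref.collapse_shape %0 [[0, 1]] : memref<2x4xf32> into memref<8xf32>
/// ```
/// If the source layout makes the dims non-collapsible, the source is first
/// copied into a new identity-layout allocation (see bufferize below).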
123 struct CollapseShapeOpInterface
124  : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
125  tensor::CollapseShapeOp> {
126  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
127  const AnalysisState &state) const {
128  // tensor.collapse_shape may reallocate, at which point the source buffer is
129  // copied. I.e., there will be a memory read side effect on the bufferized
130  // source. This function conservatively returns "true" because whether a
131  // copy will be created or not is not known at this point.
132  return true;
133  }
134 
135  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
136  const AnalysisState &state) const {
137  return false;
138  }
139 
140  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
141  const AnalysisState &state) const {
142  // TODO: CollapseShapeOp may allocate at runtime.
143  return {{op->getOpResult(0), BufferRelation::Equivalent}};
144  }
145 
146  FailureOr<BufferLikeType>
147  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
148  const BufferizationState &state,
149  SmallVector<Value> &invocationStack) const {
150  auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
151  auto maybeSrcBufferType = bufferization::getBufferType(
152  collapseShapeOp.getSrc(), options, state, invocationStack);
153  if (failed(maybeSrcBufferType))
154  return failure();
155  auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
156  bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
157  srcBufferType, collapseShapeOp.getReassociationIndices());
158 
159  if (!canBeCollapsed) {
160  // If dims cannot be collapsed, this op bufferizes to a new allocation.
161  RankedTensorType tensorResultType = collapseShapeOp.getResultType();
162  return cast<BufferLikeType>(
163  getMemRefTypeWithStaticIdentityLayout(
164  tensorResultType, srcBufferType.getMemorySpace()));
165  }
166 
167  return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
168  srcBufferType, collapseShapeOp.getReassociationIndices()));
169  }
170 
171  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
172  const BufferizationOptions &options,
173  BufferizationState &state) const {
174  auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
175  RankedTensorType tensorResultType = collapseShapeOp.getResultType();
176  FailureOr<Value> maybeBuffer =
177  getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
178  if (failed(maybeBuffer))
179  return failure();
180  Value buffer = *maybeBuffer;
181  auto bufferType = cast<MemRefType>(buffer.getType());
182 
183  if (tensorResultType.getRank() == 0) {
184  // 0-d collapses must go through a different op builder.
185  MemRefType resultType;
186 
187  if (bufferType.getLayout().isIdentity()) {
188  // Standard layout: result type has no offset.
189  MemRefLayoutAttrInterface layout;
190  resultType = MemRefType::get({}, tensorResultType.getElementType(),
191  layout, bufferType.getMemorySpace());
192  } else {
193  // Source memref has a layout map: result type has the same offset as
194  // the source type.
195  SmallVector<int64_t> strides;
196  int64_t offset;
197  if (failed(bufferType.getStridesAndOffset(strides, offset)))
198  return failure();
199  resultType = MemRefType::get(
200  {}, tensorResultType.getElementType(),
201  StridedLayoutAttr::get(op->getContext(), offset, {}),
202  bufferType.getMemorySpace());
203  }
204 
205  replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
206  rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
207  return success();
208  }
209 
210  // If the dims are not collapsible (due to an incompatible source layout
211  // map), force an out-of-place bufferization, i.e., a buffer copy. This
212  // newly allocated buffer will have no layout map and thus be collapsible.
213  bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
214  bufferType, collapseShapeOp.getReassociationIndices());
215  if (!canBeCollapsed) {
216  // TODO: Create alloc_tensor ops during TensorCopyInsertion.
217  AnalysisState analysisState(options);
218  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
219  rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
220  if (failed(tensorAlloc))
221  return failure();
222  auto memrefType =
223  MemRefType::get(collapseShapeOp.getSrcType().getShape(),
224  collapseShapeOp.getSrcType().getElementType(),
225  AffineMap(), bufferType.getMemorySpace());
226  buffer = rewriter.create<bufferization::ToBufferOp>(
227  op->getLoc(), memrefType, *tensorAlloc);
228  }
229 
230  // Result type is inferred by the builder.
231  replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
232  rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
233  return success();
234  }
235 };
236 
237 /// Bufferization of tensor.dim. Replace with memref.dim.
238 struct DimOpInterface
239  : public BufferizableOpInterface::ExternalModel<DimOpInterface,
240  tensor::DimOp> {
241  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
242  const AnalysisState &state) const {
243  // The op reads the tensor's metadata but not its contents.
244  return false;
245  }
246 
247  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
248  const AnalysisState &state) const {
249  return false;
250  }
251 
252  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
253  const AnalysisState &state) const {
254  return {};
255  }
256 
257  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
258  const BufferizationOptions &options,
259  BufferizationState &state) const {
260  auto dimOp = cast<tensor::DimOp>(op);
261  FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
262  if (failed(v))
263  return failure();
264  replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
265  dimOp.getIndex());
266  return success();
267  }
268 };
269 
270 /// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
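/// Example (illustrative sketch; %d is a made-up dynamic size):
/// ```
/// %0 = tensor.empty(%d) : tensor<4x?xf32>
/// ```
/// becomes
/// ```
/// %0 = bufferization.alloc_tensor(%d) : tensor<4x?xf32>
/// ```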
271 struct EmptyOpInterface
272  : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
273  tensor::EmptyOp> {
274  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
275 
276  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
277  const AnalysisState &state) const {
278  // The returned tensor does not have specified contents.
279  return false;
280  }
281 
282  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
283  const BufferizationOptions &options,
284  BufferizationState &state) const {
285  auto emptyOp = cast<tensor::EmptyOp>(op);
286 
287  // Optimization: Fold away the op if it has no uses.
288  if (op->getUses().empty()) {
289  rewriter.eraseOp(op);
290  return success();
291  }
292 
293  // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
294  FailureOr<Value> allocTensor = allocateTensorForShapedValue(
295  rewriter, op->getLoc(), emptyOp.getResult(), options, state,
296  /*copy=*/false);
297  if (failed(allocTensor))
298  return failure();
299  rewriter.replaceOp(op, *allocTensor);
300  return success();
301  }
302 };
303 
304 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
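/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 4]
///     : tensor<8xf32> into tensor<2x4xf32>
/// ```
/// bufferizes to
/// ```
/// %1 = memref.expand_shape %0 [[0, 1]] output_shape [2, 4]
///     : memref<8xf32> into memref<2x4xf32>
/// ```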
305 struct ExpandShapeOpInterface
306  : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
307  tensor::ExpandShapeOp> {
308  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
309  const AnalysisState &state) const {
310  // In contrast to tensor.collapse_shape, this op can always be bufferized
311  // without a copy.
312  return false;
313  }
314 
315  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
316  const AnalysisState &state) const {
317  return false;
318  }
319 
320  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
321  const AnalysisState &state) const {
322  return {{op->getOpResult(0), BufferRelation::Equivalent}};
323  }
324 
325  FailureOr<BufferLikeType>
326  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
327  const BufferizationState &state,
328  SmallVector<Value> &invocationStack) const {
329  auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
330  auto maybeSrcBufferType = bufferization::getBufferType(
331  expandShapeOp.getSrc(), options, state, invocationStack);
332  if (failed(maybeSrcBufferType))
333  return failure();
334  auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
335  auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
336  srcBufferType, expandShapeOp.getResultType().getShape(),
337  expandShapeOp.getReassociationIndices());
338  if (failed(maybeResultType))
339  return failure();
340  return cast<BufferLikeType>(*maybeResultType);
341  }
342 
343  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
344  const BufferizationOptions &options,
345  BufferizationState &state) const {
346  auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
347  auto tensorResultType = expandShapeOp.getResultType();
348  FailureOr<Value> buffer =
349  getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
350  if (failed(buffer))
351  return failure();
352 
353  auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
354  op->getLoc(), tensorResultType.getShape(), *buffer,
355  expandShapeOp.getReassociationIndices(),
356  expandShapeOp.getMixedOutputShape());
357  replaceOpWithBufferizedValues(rewriter, op,
358  memrefExpandShape->getResults());
359  return success();
360  }
361 };
362 
363 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
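/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %1 = tensor.extract_slice %0[0, 4] [4, 8] [1, 1]
///     : tensor<8x16xf32> to tensor<4x8xf32>
/// ```
/// bufferizes to
/// ```
/// %1 = memref.subview %0[0, 4] [4, 8] [1, 1]
///     : memref<8x16xf32> to memref<4x8xf32, strided<[16, 1], offset: 4>>
/// ```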
364 struct ExtractSliceOpInterface
365  : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
366  tensor::ExtractSliceOp> {
367  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
368  const AnalysisState &state) const {
369  return false;
370  }
371 
372  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
373  const AnalysisState &state) const {
374  return false;
375  }
376 
377  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
378  const AnalysisState &state) const {
379  return {{op->getOpResult(0), BufferRelation::Unknown}};
380  }
381 
382  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
383  const BufferizationOptions &options,
384  BufferizationState &state) const {
385  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
386  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
387  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
388  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
389  Location loc = extractSliceOp.getLoc();
390 
391  // Get source buffer.
392  FailureOr<Value> srcMemref =
393  getBuffer(rewriter, extractSliceOp.getSource(), options, state);
394  if (failed(srcMemref))
395  return failure();
396 
397  // Take a subview of the source buffer.
398  auto resultMemrefType = bufferization::getBufferType(
399  extractSliceOp.getResult(), options, state);
400  if (failed(resultMemrefType))
401  return failure();
402  Value subView = rewriter.create<memref::SubViewOp>(
403  loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
404  mixedOffsets, mixedSizes, mixedStrides);
405 
406  replaceOpWithBufferizedValues(rewriter, op, subView);
407  return success();
408  }
409 
410  FailureOr<BufferLikeType>
411  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
412  const BufferizationState &state,
413  SmallVector<Value> &invocationStack) const {
414  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
415  assert(value == extractSliceOp.getResult() && "invalid value");
416  auto srcMemrefType = bufferization::getBufferType(
417  extractSliceOp.getSource(), options, state, invocationStack);
418  if (failed(srcMemrefType))
419  return failure();
420  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
421  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
422  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
423  return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
424  extractSliceOp.getType().getShape(),
425  llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
426  mixedStrides));
427  }
428 };
429 
430 /// Bufferization of tensor.extract. Replace with memref.load.
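/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %0 = tensor.extract %t[%i, %j] : tensor<4x4xf32>
/// ```
/// bufferizes to
/// ```
/// %0 = memref.load %t_buffer[%i, %j] : memref<4x4xf32>
/// ```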
431 struct ExtractOpInterface
432  : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
433  tensor::ExtractOp> {
434  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
435  const AnalysisState &state) const {
436  return true;
437  }
438 
439  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
440  const AnalysisState &state) const {
441  return false;
442  }
443 
444  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
445  const AnalysisState &state) const {
446  return {};
447  }
448 
449  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
450  const BufferizationOptions &options,
451  BufferizationState &state) const {
452  auto extractOp = cast<tensor::ExtractOp>(op);
453  FailureOr<Value> srcMemref =
454  getBuffer(rewriter, extractOp.getTensor(), options, state);
455  if (failed(srcMemref))
456  return failure();
457  replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
458  extractOp.getIndices());
459  return success();
460  }
461 };
462 
463 // Implements backtracking to traverse indices of the output buffer while
464 // iterating over op.elements().
465 static void createStores(RewriterBase &rewriter, Location loc, int dim,
466  Value buffer, ArrayRef<int64_t> shape,
467  ArrayRef<Value> constants,
468  OperandRange::iterator &elementIt,
469  SmallVectorImpl<Value> &indices) {
470  if (dim == static_cast<int>(shape.size()) - 1) {
471  for (int i = 0; i < shape.back(); ++i) {
472  indices.back() = constants[i];
473  rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
474  ++elementIt;
475  }
476  return;
477  }
478  for (int i = 0; i < shape[dim]; ++i) {
479  indices[dim] = constants[i];
480  createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
481  indices);
482  }
483 }
484 
485 /// Bufferization of tensor.from_elements.
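/// The result is materialized as a new allocation with one memref.store per
/// element, e.g. (illustrative sketch; SSA names are made up):
/// ```
/// %t = tensor.from_elements %a, %b : tensor<2xf32>
/// ```
/// becomes, roughly,
/// ```
/// ... allocate %buf : memref<2xf32> ...
/// memref.store %a, %buf[%c0] : memref<2xf32>
/// memref.store %b, %buf[%c1] : memref<2xf32>
/// ```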
486 struct FromElementsOpInterface
487  : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
488  tensor::FromElementsOp> {
489 
490  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
491 
492  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
493  const BufferizationOptions &options,
494  BufferizationState &state) const {
495  auto fromElementsOp = cast<tensor::FromElementsOp>(op);
496  auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
497 
498  // Allocate a buffer for the result.
499  Location loc = op->getLoc();
500  auto shape = tensorType.getShape();
501  // TODO: Create alloc_tensor ops during TensorCopyInsertion.
502  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
503  rewriter, loc, fromElementsOp.getResult(), options, state,
504  /*copy=*/false);
505  if (failed(tensorAlloc))
506  return failure();
507  FailureOr<BufferLikeType> memrefType =
508  bufferization::getBufferType(*tensorAlloc, options, state);
509  if (failed(memrefType))
510  return failure();
511  Value buffer = rewriter.create<bufferization::ToBufferOp>(
512  op->getLoc(), *memrefType, *tensorAlloc);
513 
514  // Case: tensor<0xelem_type>.
515  if (fromElementsOp.getElements().empty()) {
516  replaceOpWithBufferizedValues(rewriter, op, buffer);
517  return success();
518  }
519 
520  // Case: tensor<elem_type>.
521  if (shape.empty()) {
522  rewriter.create<memref::StoreOp>(
523  loc, fromElementsOp.getElements().front(), buffer);
524  replaceOpWithBufferizedValues(rewriter, op, buffer);
525  return success();
526  }
527 
528  // Create constants for the range of possible indices [0, max{shape_i}).
529  auto maxDim = *llvm::max_element(shape);
530  SmallVector<Value, 2> constants;
531  constants.reserve(maxDim);
532  for (int i = 0; i < maxDim; ++i)
533  constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
534 
535  // Traverse all `elements` and create `memref.store` ops.
536  auto elementIt = fromElementsOp.getElements().begin();
537  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
538  createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
539  indices);
540 
541  replaceOpWithBufferizedValues(rewriter, op, buffer);
542 
543  return success();
544  }
545 };
546 
547 /// Lower the body of a tensor.generate-like op (one index-typed bbArg per dim).
548 /// Such ops are lowered to linalg.map with the given tensor as a destination.
549 ///
550 /// Example:
551 /// ```
552 /// %r = tensor.generate %x, %y {
553 /// ^bb0(%arg0: index, %arg1: index):
554 /// %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
555 /// tensor.yield %0 : index
556 /// } : tensor<?x?xindex>
557 /// ```
558 ///
559 /// Is lowered to:
560 /// ```
561 /// linalg.map ins() outs(%dest) {
562 /// %d0 = linalg.index 0 : index
563 /// %d1 = linalg.index 1 : index
564 /// %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
565 /// linalg.yield %0 : index
566 /// }
567 /// ```
568 static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
569  Value tensorDestination,
570  ValueRange dynamicSizes,
571  Region &generateBody) {
572  assert(generateBody.hasOneBlock() && "expected body with single block");
573  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
574  assert(generateBody.getNumArguments() == tensorType.getRank() &&
575  "rank mismatch");
576 
577  // Create linalg::MapOp.
578  OpBuilder::InsertionGuard g(rewriter);
579  auto linalgOp =
580  rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
581  /*init=*/tensorDestination);
582  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
583 
584  // Create linalg::IndexOps.
585  rewriter.setInsertionPointToStart(&linalgBody);
586  SmallVector<Value> indices;
587  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
588  indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));
589 
590  // Move over body.
591  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
592  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
593  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
594 
595  return linalgOp.getResult()[0];
596 }
597 
598 /// Bufferization of tensor.generate.
599 struct GenerateOpInterface
600  : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
601  tensor::GenerateOp> {
602 
603  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
604 
605  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
606  const BufferizationOptions &options,
607  BufferizationState &state) const {
608  auto generateOp = cast<tensor::GenerateOp>(op);
609 
610  auto type = generateOp.getResult().getType();
611 
612  // TODO: Implement memory space for this op.
613  if (options.defaultMemorySpaceFn(type) != Attribute())
614  return op->emitError("memory space not implemented yet");
615 
616  // Allocate memory.
617  Location loc = op->getLoc();
618  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
619  rewriter, loc, generateOp.getResult(), options, state,
620  /*copy=*/false);
621  if (failed(tensorAlloc))
622  return failure();
623 
624  Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
625  generateOp.getDynamicExtents(),
626  generateOp.getBody());
627  rewriter.replaceOp(generateOp, result);
628 
629  return success();
630  }
631 };
632 
633 /// Bufferization of tensor.insert. Replace with memref.store.
634 ///
635 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
636 /// implementations for DestinationStyle ops.
637 struct InsertOpInterface
638  : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
639  tensor::InsertOp> {
640  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
641  const BufferizationOptions &options,
642  BufferizationState &state) const {
643  auto insertOp = cast<tensor::InsertOp>(op);
644  FailureOr<Value> destMemref =
645  getBuffer(rewriter, insertOp.getDest(), options, state);
646  if (failed(destMemref))
647  return failure();
648  rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
649  *destMemref, insertOp.getIndices());
650  replaceOpWithBufferizedValues(rewriter, op, *destMemref);
651  return success();
652  }
653 };
654 
655 template <typename InsertOpTy>
656 static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
657  OpOperand &opOperand) {
658  // The source is always read.
659  if (opOperand == insertSliceOp.getSourceMutable())
660  return true;
661 
662  // For the destination, it depends...
663  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
664 
665  // Dest is not read if it is entirely overwritten. E.g.:
666  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
667  bool allOffsetsZero =
668  llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
669  RankedTensorType destType = insertSliceOp.getDestType();
670  bool sizesMatchDestSizes =
671  areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
672  bool allStridesOne =
673  areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
674  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
675 }
676 
677 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
678 /// certain circumstances, this op can also be a no-op.
679 ///
680 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
681 /// implementations for DestinationStyle ops.
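///
/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %r = tensor.insert_slice %src into %dst[4] [8] [1]
///     : tensor<8xf32> into tensor<16xf32>
/// ```
/// bufferizes, roughly, to a subview of the destination buffer plus a copy:
/// ```
/// %sv = memref.subview %dst_buffer[4] [8] [1]
///     : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
/// memref.copy %src_buffer, %sv
///     : memref<8xf32> to memref<8xf32, strided<[1], offset: 4>>
/// ```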
682 struct InsertSliceOpInterface
683  : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
684  tensor::InsertSliceOp> {
685  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
686  const AnalysisState &state) const {
687  return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
688  opOperand);
689  }
690 
691  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
692  const BufferizationOptions &options,
693  BufferizationState &state) const {
694  // insert_slice ops arise from tiling and bufferizing them out-of-place is
695  // generally a deal breaker. When used with loops, this ends up cloning the
696  // whole tensor on every single iteration and is a symptom of a
697  // catastrophically bad scheduling decision.
698  // TODO: be very loud about it or even consider failing the pass.
699  auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
700  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
701  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
702  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
703  Location loc = insertSliceOp.getLoc();
704 
705  // Get destination buffer.
706  FailureOr<Value> dstMemref =
707  getBuffer(rewriter, insertSliceOp.getDest(), options, state);
708  if (failed(dstMemref))
709  return failure();
710 
711  // Take a subview of the destination buffer.
712  auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
713  MemRefType subviewMemRefType =
714  memref::SubViewOp::inferRankReducedResultType(
715  insertSliceOp.getSourceType().getShape(), dstMemrefType,
716  mixedOffsets, mixedSizes, mixedStrides);
717  Value subView = rewriter.create<memref::SubViewOp>(
718  loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
719  mixedStrides);
720 
721  // Copy tensor. If this tensor.insert_slice has a matching
722  // tensor.extract_slice, the copy operation will eventually fold away.
723  FailureOr<Value> srcMemref =
724  getBuffer(rewriter, insertSliceOp.getSource(), options, state);
725  if (failed(srcMemref))
726  return failure();
727  if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
728  return failure();
729 
730  replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
731  return success();
732  }
733 };
734 
735 /// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
736 /// linalg.map + insert_slice.
737 /// For best performance, vectorize before bufferization (better performance in
738 /// case of padding with a constant).
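///
/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %p = tensor.pad %t low[1] high[2] {
/// ^bb0(%i: index):
///   tensor.yield %cst : f32
/// } : tensor<5xf32> to tensor<8xf32>
/// ```
/// bufferizes, roughly, to: allocate a tensor<8xf32>, fill it with a
/// linalg.map whose body is the pad body, then tensor.insert_slice %t into
/// the filled tensor at offset 1 with size 5 and stride 1.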
739 struct PadOpInterface
740  : public BufferizableOpInterface::ExternalModel<PadOpInterface,
741  tensor::PadOp> {
742  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
743 
744  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
745  const AnalysisState &state) const {
746  return true;
747  }
748 
749  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
750  const AnalysisState &state) const {
751  return false;
752  }
753 
754  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
755  const AnalysisState &state) const {
756  return {};
757  }
758 
759  FailureOr<BufferLikeType>
760  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
761  const BufferizationState &state,
762  SmallVector<Value> &invocationStack) const {
763  // Infer memory space from the source tensor.
764  auto padOp = cast<tensor::PadOp>(op);
765  auto maybeSrcBufferType =
766  asMemRefType(bufferization::getBufferType(
767  padOp.getSource(), options, state, invocationStack));
768  if (failed(maybeSrcBufferType))
769  return failure();
770  MemRefLayoutAttrInterface layout;
771  return cast<BufferLikeType>(
772  MemRefType::get(padOp.getResultType().getShape(),
773  padOp.getResultType().getElementType(), layout,
774  maybeSrcBufferType->getMemorySpace()));
775  }
776 
777  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
778  const BufferizationOptions &options,
779  BufferizationState &state) const {
780  auto padOp = cast<tensor::PadOp>(op);
781  Location loc = padOp.getLoc();
782  RankedTensorType resultType = padOp.getResultType();
783  RankedTensorType srcType = padOp.getSourceType();
784 
785  auto toValue = [&](OpFoldResult ofr) {
786  if (auto value = dyn_cast<Value>(ofr))
787  return value;
788  return rewriter
789  .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
790  .getResult();
791  };
792 
793  // Compute dynamic result dimensions.
794  SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
795  SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
796  SmallVector<Value> dynamicSizes;
797  for (int64_t i = 0; i < resultType.getRank(); ++i) {
798  if (!resultType.isDynamicDim(i))
799  continue;
800  Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
801  Value lowPad = toValue(mixedLowPad[i]);
802  Value highPad = toValue(mixedHighPad[i]);
803  AffineExpr s0, s1, s2;
804  bindSymbols(op->getContext(), s0, s1, s2);
805  AffineExpr sumExpr = s0 + s1 + s2;
806  Value sum = rewriter.create<affine::AffineApplyOp>(
807  loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
808  dynamicSizes.push_back(sum);
809  }
810 
811  // Allocate a buffer for the padded result.
812  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
813  rewriter, loc, padOp.getResult(), options, state,
814  /*copy=*/false);
815  if (failed(tensorAlloc))
816  return failure();
817 
818  // tensor::PadOp is like tensor::GenerateOp: The only difference is that
819  // only a part of the generated tensor is needed. For simplicity, we reuse
820  // the same functionality here.
821  Value filledBuffer = lowerGenerateLikeOpBody(
822  rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
823 
824  // Create tensor::InsertSliceOp.
825  SmallVector<OpFoldResult> sliceSizes =
826  getMixedSizes(rewriter, loc, padOp.getSource());
827  SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
828  rewriter.getIndexAttr(1));
829  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
830  padOp, padOp.getSource(), filledBuffer,
831  /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
832 
833  return success();
834  }
835 };
836 
837 /// Bufferization of tensor.rank. Replace with memref.rank.
838 struct RankOpInterface
839  : public BufferizableOpInterface::ExternalModel<RankOpInterface,
840  tensor::RankOp> {
841  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
842  const AnalysisState &state) const {
843  // The op reads the tensor's metadata but not its contents.
844  return false;
845  }
846 
847  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
848  const AnalysisState &state) const {
849  return false;
850  }
851 
852  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
853  const AnalysisState &state) const {
854  return {};
855  }
856 
857  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
858  const BufferizationOptions &options,
859  BufferizationState &state) const {
860  auto rankOp = cast<tensor::RankOp>(op);
861  FailureOr<Value> v =
862  getBuffer(rewriter, rankOp.getTensor(), options, state);
863  if (failed(v))
864  return failure();
865  replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
866  *v);
867  return success();
868  }
869 };
870 
871 /// Bufferization of tensor.reshape. Replace with memref.reshape.
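/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %r = tensor.reshape %t(%shape)
///     : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// ```
/// bufferizes to
/// ```
/// %r = memref.reshape %t_buffer(%shape_buffer)
///     : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// ```
/// possibly preceded by a copy of the source into an identity-layout buffer.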
872 struct ReshapeOpInterface
873  : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
874  tensor::ReshapeOp> {
875  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
876  const AnalysisState &state) const {
877  // Depending on the layout map, the source buffer may have to be copied.
878  auto reshapeOp = cast<tensor::ReshapeOp>(op);
879  return opOperand == reshapeOp.getShapeMutable();
880  }
881 
882  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
883  const AnalysisState &state) const {
884  return false;
885  }
886 
887  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
888  const AnalysisState &state) const {
889  // Only the 'source' operand aliases the result.
890  auto reshapeOp = cast<tensor::ReshapeOp>(op);
891  if (reshapeOp.getSourceMutable() != opOperand)
892  return {};
893  return {{op->getOpResult(0), BufferRelation::Equivalent}};
894  }
895 
896  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
897  const BufferizationOptions &options,
898  BufferizationState &state) const {
899  auto reshapeOp = cast<tensor::ReshapeOp>(op);
900  FailureOr<Value> srcBuffer =
901  getBuffer(rewriter, reshapeOp.getSource(), options, state);
902  FailureOr<Value> shapeBuffer =
903  getBuffer(rewriter, reshapeOp.getShape(), options, state);
904  if (failed(srcBuffer) || failed(shapeBuffer))
905  return failure();
906  auto maybeResultMemRefType =
907  bufferization::getBufferType(reshapeOp.getResult(), options, state);
908  if (failed(maybeResultMemRefType))
909  return failure();
910 
911  // memref.reshape requires the source buffer to have an identity layout.
912  // If the source memref does not have an identity layout, copy the source
913  // into a new buffer with an identity layout.
914  auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
915  if (srcType && !srcType.getLayout().isIdentity()) {
916  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
917  rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
918  if (failed(tensorAlloc))
919  return failure();
920  auto memrefType = MemRefType::get(
921  srcType.getShape(), srcType.getElementType(), AffineMap(),
922  cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
923  srcBuffer = rewriter
924  .create<bufferization::ToBufferOp>(
925  op->getLoc(), memrefType, *tensorAlloc)
926  .getResult();
927  }
928 
929  replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
930  rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
931  return success();
932  }
933 
934  FailureOr<BufferLikeType>
935  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
936  const BufferizationState &state,
937  SmallVector<Value> &invocationStack) const {
938  auto reshapeOp = cast<tensor::ReshapeOp>(op);
939  assert(value == reshapeOp.getResult() && "unexpected value provided");
940  auto maybeSourceBufferType = bufferization::getBufferType(
941  reshapeOp.getSource(), options, state, invocationStack);
942  if (failed(maybeSourceBufferType))
943  return failure();
944  return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
945  reshapeOp.getResult().getType(),
946  cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
947  }
948 };
949 
950 /// Analysis of ParallelInsertSliceOp.
951 struct ParallelInsertSliceOpInterface
952  : public BufferizableOpInterface::ExternalModel<
953  ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
954  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
955  const AnalysisState &state) const {
956  return {};
957  }
958 
959  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
960  const AnalysisState &state) const {
961  return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
962  }
963 
964  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
965  const AnalysisState &state) const {
966  auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
967  return opOperand == parallelInsertSliceOp.getDestMutable();
968  }
969 
970  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
971  const BufferizationOptions &options,
972  BufferizationState &state) const {
973  OpBuilder::InsertionGuard g(rewriter);
974  auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
975  ParallelCombiningOpInterface parallelCombiningParent =
976  parallelInsertSliceOp.getParallelCombiningParent();
977 
978  // Bufferize the op outside of the parallel combining terminator.
979  rewriter.setInsertionPoint(parallelCombiningParent);
980 
981  // Get source and destination buffers.
982  FailureOr<Value> destBuffer =
983  getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
984  if (failed(destBuffer))
985  return failure();
986  FailureOr<Value> srcBuffer =
987  getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
988  if (failed(srcBuffer))
989  return failure();
990 
991  // Take a subview of the destination buffer.
992  auto destBufferType = cast<MemRefType>(destBuffer->getType());
993  MemRefType subviewMemRefType =
994  memref::SubViewOp::inferRankReducedResultType(
995  parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
996  parallelInsertSliceOp.getMixedOffsets(),
997  parallelInsertSliceOp.getMixedSizes(),
998  parallelInsertSliceOp.getMixedStrides());
999  Value subview = rewriter.create<memref::SubViewOp>(
1000  parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
1001  parallelInsertSliceOp.getMixedOffsets(),
1002  parallelInsertSliceOp.getMixedSizes(),
1003  parallelInsertSliceOp.getMixedStrides());
1004 
1005  // This memcpy will fold away if everything bufferizes in-place.
1006  if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
1007  *srcBuffer, subview)))
1008  return failure();
1009 
1010  // In case the source was allocated in the same block, make sure that the
1011  // deallocation op (if any) appears after the memcpy. By default, deallocs
1012  // are placed before the terminator, but this does not work for ForallOp
1013  // because the terminator does more than just yielding a value.
1014  //
1015  // Note: This is not a problem for the destination buffer because these are
1016  // assumed to always bufferize in-place.
1017  for (Operation *user : srcBuffer->getUsers()) {
1018  if (hasEffect<MemoryEffects::Free>(user)) {
1019  if (user->getBlock() == parallelCombiningParent->getBlock())
1020  rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
1021  break;
1022  }
1023  }
1024 
1025  // Delete the op.
1026  rewriter.eraseOp(op);
1027  return success();
1028  }
1029 
1030  /// tensor.parallel_insert_slice op has implicit inplace behavior. We
1031  /// should not create a copy to resolve conflicts.
1032  LogicalResult
1033  resolveConflicts(Operation *op, RewriterBase &rewriter,
1034  const AnalysisState &analysisState,
1035  const BufferizationState &bufferizationState) const {
1036  return success();
1037  }
1038 };
1039 
1040 /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
1041 /// with a linalg.map. Similar to tensor.generate.
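/// Example (illustrative sketch, using the same loose notation as the
/// tensor.generate example above; SSA names are made up):
/// ```
/// %t = tensor.splat %v : tensor<8xf32>
/// ```
/// becomes, roughly,
/// ```
/// %alloc = bufferization.alloc_tensor() : tensor<8xf32>
/// %t = linalg.map ins() outs(%alloc) {
///   linalg.yield %v : f32
/// }
/// ```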
1042 struct SplatOpInterface
1043  : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1044  tensor::SplatOp> {
1045 
1046  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1047 
1048  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1049  const BufferizationOptions &options,
1050  BufferizationState &state) const {
1051  OpBuilder::InsertionGuard g(rewriter);
1052  auto splatOp = cast<tensor::SplatOp>(op);
1053 
1054  // Allocate memory.
1055  Location loc = op->getLoc();
1056  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1057  rewriter, loc, splatOp.getResult(), options, state,
1058  /*copy=*/false);
1059  if (failed(tensorAlloc))
1060  return failure();
1061 
1062  // Create linalg::MapOp.
1063  auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1064 
1065  // TODO: Implement memory space for this op.
1066  if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1067  return op->emitError("memory space not implemented yet");
1068 
1069  auto linalgOp =
1070  rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
1071  /*init=*/*tensorAlloc);
1072  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1073 
1074  // Create linalg::IndexOps.
1075  rewriter.setInsertionPointToStart(&linalgBody);
1076  rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
1077  rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
1078 
1079  return success();
1080  }
1081 };
1082 
1083 /// Bufferization of tensor.concat. Bufferizes to a new allocation that is
1084 /// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
1085 /// on subviews instead of memref.store.
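///
/// Example (illustrative sketch; SSA names and types are made up):
/// ```
/// %r = tensor.concat dim(0) %a, %b
///     : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
/// ```
/// bufferizes, roughly, to an allocation of memref<5x4xf32> plus one
/// memref.subview + memref.copy per input, with the offset along dim 0
/// advancing by each input's size (0, then 2).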
1086 struct ConcatOpInterface
1087  : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
1088  tensor::ConcatOp> {
1089 
1090  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1091 
1092  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1093  const AnalysisState &state) const {
1094  return false;
1095  }
1096 
1097  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1098  const AnalysisState &state) const {
1099  return true;
1100  }
1101 
1102  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1103  const AnalysisState &state) const {
1104  return {};
1105  }
1106 
1107  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1108  const BufferizationOptions &options,
1109  BufferizationState &state) const {
1110  OpBuilder::InsertionGuard g(rewriter);
1111  auto concatOp = cast<tensor::ConcatOp>(op);
1112 
1113  // Allocate memory.
1114  Location loc = op->getLoc();
1115  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1116  rewriter, loc, concatOp.getResult(), options, state,
1117  /*copy=*/false);
1118  if (failed(tensorAlloc))
1119  return failure();
1120  auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1121 
1122  // TODO: Implement memory space for this op.
1123  if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1124  return op->emitError("memory space not implemented yet");
1125 
1126  MemRefLayoutAttrInterface layout;
1127  MemRefType memrefType =
1128  MemRefType::get(concatOp.getResultType().getShape(),
1129  concatOp.getResultType().getElementType(), layout);
1130  Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
1131  op->getLoc(), memrefType, *tensorAlloc);
1132 
1133  // Extract the dimension for the concat op
1134  uint64_t concatDim = concatOp.getDim();
1135  bool dynamicConcatDim = false;
1136 
1137  SmallVector<OpFoldResult> offsets(tensorType.getRank(),
1138  rewriter.getIndexAttr(0));
1139  SmallVector<OpFoldResult> strides(tensorType.getRank(),
1140  rewriter.getIndexAttr(1));
1141  SmallVector<OpFoldResult> sizes;
1142 
1143  for (const auto &[dimIdx, dimSize] :
1144  llvm::enumerate(tensorType.getShape())) {
1145  if (dimSize == ShapedType::kDynamic) {
1146  auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
1147  sizes.push_back(dimOp.getResult());
1148  if (dimIdx == concatDim)
1149  dynamicConcatDim = true;
1150  } else {
1151  sizes.push_back(rewriter.getIndexAttr(dimSize));
1152  }
1153  }
1154 
1155  int64_t concatDimOffset = 0;
1156  std::optional<Value> dynamicOffset;
1157  std::optional<Value> dynamicSize;
1158  if (dynamicConcatDim) {
1159  // One or more operands have dynamic size, so we must accumulate the
1160  // offset with arith ops.
1161  dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1162  }
1163 
1164  for (auto operand : concatOp.getInputs()) {
1165  // Get the buffer for the operand.
1166  FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
1167  if (failed(srcBuffer))
1168  return failure();
1169 
1170  // Each operand may have a different size along the concat dimension,
1171  // so the offset on that axis must accumulate through the loop, and the
1172  // size must change to the size of the current operand.
1173  auto operandTensorType = cast<RankedTensorType>(operand.getType());
1174  int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
1175 
1176  if (dynamicConcatDim) {
1177  offsets[concatDim] = dynamicOffset.value();
1178  dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
1179  .getResult();
1180  sizes[concatDim] = dynamicSize.value();
1181  } else {
1182  sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
1183  offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
1184  }
1185 
1186  // Create a subview of the destination buffer.
1187  auto dstMemrefType = cast<MemRefType>(memrefType);
1188  MemRefType subviewMemRefType =
1189  memref::SubViewOp::inferRankReducedResultType(
1190  operandTensorType.getShape(), dstMemrefType, offsets, sizes,
1191  strides);
1192  Value subview = rewriter.create<memref::SubViewOp>(
1193  loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
1194 
1195  // Copy the source buffer into the destination subview.
1196  if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
1197  return failure();
1198 
1199  if (dynamicConcatDim) {
1200  dynamicOffset = rewriter.create<arith::AddIOp>(
1201  loc, dynamicOffset.value(), dynamicSize.value());
1202  } else {
1203  concatDimOffset += operandConcatDimSize;
1204  }
1205  }
1206 
1207  replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
1208  return success();
1209  }
1210 };
1211 
1212 } // namespace
1213 } // namespace tensor
1214 } // namespace mlir
1215 
1216 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
1217  DialectRegistry &registry) {
1218  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
1219  CastOp::attachInterface<CastOpInterface>(*ctx);
1220  CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1221  ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
1222  DimOp::attachInterface<DimOpInterface>(*ctx);
1223  EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1224  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1225  ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1226  ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1227  FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1228  GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1229  InsertOp::attachInterface<InsertOpInterface>(*ctx);
1230  InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1231  PadOp::attachInterface<PadOpInterface>(*ctx);
1232  ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1233  *ctx);
1234  RankOp::attachInterface<RankOpInterface>(*ctx);
1235  ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1236  SplatOp::attachInterface<SplatOpInterface>(*ctx);
1237 
1238  // Load additional dialects of which ops may get created.
1239  ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
1240  });
1241 
1242  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
1243  // they are registered.
1244  registerSubsetOpInterfaceExternalModels(registry);
1245 }