1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
22 #include "mlir/Dialect/Utils/StaticValueUtils.h"
23 #include "mlir/IR/Dialect.h"
24 #include "mlir/IR/Operation.h"
25 
26 using namespace mlir;
27 using namespace mlir::bufferization;
28 using namespace mlir::tensor;
29 
30 namespace mlir {
31 namespace tensor {
32 namespace {
33 
34 struct CastOpInterface
35  : public BufferizableOpInterface::ExternalModel<CastOpInterface,
36  tensor::CastOp> {
37  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
38  const AnalysisState &state) const {
39  return false;
40  }
41 
42  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
43  const AnalysisState &state) const {
44  return false;
45  }
46 
47  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
48  const AnalysisState &state) const {
49  return {{op->getResult(0), BufferRelation::Equivalent}};
50  }
51 
52  FailureOr<BaseMemRefType>
53  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
54  const BufferizationState &state,
55  SmallVector<Value> &invocationStack) const {
56  auto castOp = cast<tensor::CastOp>(op);
57  auto maybeSrcBufferType = bufferization::getBufferType(
58  castOp.getSource(), options, state, invocationStack);
59  if (failed(maybeSrcBufferType))
60  return failure();
61  Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
62 
63  // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
64  // type in case the input is an unranked tensor type.
65 
66  // Case 1: Casting an unranked tensor
67  if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
68  // When casting to a ranked tensor, we cannot infer any static offset or
69  // strides from the source. Assume fully dynamic.
70  return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
71  }
72 
73  // Case 2: Casting to an unranked tensor type
74  if (isa<UnrankedTensorType>(castOp.getType())) {
75  return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
76  }
77 
78  // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
79  // change.
80  auto rankedResultType = cast<RankedTensorType>(castOp.getType());
81  return MemRefType::get(
82  rankedResultType.getShape(), rankedResultType.getElementType(),
83  llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
84  }
85 
86  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
87  const BufferizationOptions &options,
88  BufferizationState &state) const {
89  auto castOp = cast<tensor::CastOp>(op);
90 
91  // The result buffer still has the old (pre-cast) type.
92  FailureOr<Value> resultBuffer =
93  getBuffer(rewriter, castOp.getSource(), options, state);
94  if (failed(resultBuffer))
95  return failure();
96 
97  // Compute the new type.
98  auto resultMemRefType =
99  bufferization::getBufferType(castOp.getResult(), options, state);
100  if (failed(resultMemRefType))
101  return failure();
102  if (resultBuffer->getType() == *resultMemRefType) {
103  // This cast is a no-op.
104  replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
105  return success();
106  }
107 
108  // Replace the op with a memref.cast.
109  assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
110  *resultMemRefType) &&
111  "CastOp::bufferize: cast incompatible");
112  replaceOpWithNewBufferizedOp<memref::CastOp>(
113  rewriter, op, *resultMemRefType, *resultBuffer);
114 
115  return success();
116  }
117 };
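
// Example (a sketch; exact layouts depend on the bufferization options in
// use): a ranked-to-ranked cast only changes static shape information, so
//
//   %1 = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
//
// bufferizes to a memref.cast on the source buffer, e.g.
//
//   %1 = memref.cast %m : memref<?xf32, strided<[?], offset: ?>>
//                      to memref<10xf32, strided<[?], offset: ?>>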
118 
119 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
120 struct CollapseShapeOpInterface
121  : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
122  tensor::CollapseShapeOp> {
123  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
124  const AnalysisState &state) const {
125  // tensor.collapse_shape may reallocate, at which point the source buffer is
126  // copied. I.e., there will be a memory read side effect on the bufferized
127  // source. This function conservatively returns "true" because whether a
128  // copy will be created or not is not known at this point.
129  return true;
130  }
131 
132  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
133  const AnalysisState &state) const {
134  return false;
135  }
136 
137  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
138  const AnalysisState &state) const {
139  // TODO: CollapseShapeOp may allocate at runtime.
140  return {{op->getOpResult(0), BufferRelation::Equivalent}};
141  }
142 
143  FailureOr<BaseMemRefType>
144  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
145  const BufferizationState &state,
146  SmallVector<Value> &invocationStack) const {
147  auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
148  auto maybeSrcBufferType = bufferization::getBufferType(
149  collapseShapeOp.getSrc(), options, state, invocationStack);
150  if (failed(maybeSrcBufferType))
151  return failure();
152  auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
153  bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
154  srcBufferType, collapseShapeOp.getReassociationIndices());
155 
156  if (!canBeCollapsed) {
157  // If dims cannot be collapsed, this op bufferizes to a new allocation.
158  RankedTensorType tensorResultType = collapseShapeOp.getResultType();
159  return getMemRefTypeWithStaticIdentityLayout(
160  tensorResultType, srcBufferType.getMemorySpace());
161  }
162 
163  return memref::CollapseShapeOp::computeCollapsedType(
164  srcBufferType, collapseShapeOp.getReassociationIndices());
165  }
166 
167  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
168  const BufferizationOptions &options,
169  BufferizationState &state) const {
170  auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
171  RankedTensorType tensorResultType = collapseShapeOp.getResultType();
172  FailureOr<Value> maybeBuffer =
173  getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
174  if (failed(maybeBuffer))
175  return failure();
176  Value buffer = *maybeBuffer;
177  auto bufferType = cast<MemRefType>(buffer.getType());
178 
179  if (tensorResultType.getRank() == 0) {
180  // 0-d collapses must go through a different op builder.
181  MemRefType resultType;
182 
183  if (bufferType.getLayout().isIdentity()) {
184  // Standard layout: result type has no offset.
185  MemRefLayoutAttrInterface layout;
186  resultType = MemRefType::get({}, tensorResultType.getElementType(),
187  layout, bufferType.getMemorySpace());
188  } else {
189  // Source memref has a layout map: result type has the same offset as
190  // the source type.
191  SmallVector<int64_t> strides;
192  int64_t offset;
193  if (failed(bufferType.getStridesAndOffset(strides, offset)))
194  return failure();
195  resultType = MemRefType::get(
196  {}, tensorResultType.getElementType(),
197  StridedLayoutAttr::get(op->getContext(), offset, {}),
198  bufferType.getMemorySpace());
199  }
200 
201  replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
202  rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
203  return success();
204  }
205 
206  // If the dims are not collapsible (due to an incompatible source layout
207  // map), force an out-of-place bufferization, i.e., a buffer copy. This
208  // newly allocated buffer will have no layout map and thus be collapsible.
209  bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
210  bufferType, collapseShapeOp.getReassociationIndices());
211  if (!canBeCollapsed) {
212  // TODO: Create alloc_tensor ops during TensorCopyInsertion.
213  AnalysisState analysisState(options);
214  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
215  rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
216  if (failed(tensorAlloc))
217  return failure();
218  auto memrefType =
219  MemRefType::get(collapseShapeOp.getSrcType().getShape(),
220  collapseShapeOp.getSrcType().getElementType(),
221  AffineMap(), bufferType.getMemorySpace());
222  buffer = rewriter.create<bufferization::ToBufferOp>(
223  op->getLoc(), memrefType, *tensorAlloc);
224  }
225 
226  // Result type is inferred by the builder.
227  replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
228  rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
229  return success();
230  }
231 };
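
// Example (sketch): with a contiguous source buffer,
//
//   %c = tensor.collapse_shape %t [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
//
// bufferizes to
//
//   %c = memref.collapse_shape %m [[0, 1]] : memref<2x4xf32> into memref<8xf32>
//
// If the source layout makes the reassociation non-collapsible, the source is
// first copied into a fresh identity-layout allocation (see above).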
232 
233 /// Bufferization of tensor.dim. Replace with memref.dim.
234 struct DimOpInterface
235  : public BufferizableOpInterface::ExternalModel<DimOpInterface,
236  tensor::DimOp> {
237  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
238  const AnalysisState &state) const {
239  // The op reads the tensor's metadata but not its contents.
240  return false;
241  }
242 
243  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
244  const AnalysisState &state) const {
245  return false;
246  }
247 
248  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
249  const AnalysisState &state) const {
250  return {};
251  }
252 
253  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
254  const BufferizationOptions &options,
255  BufferizationState &state) const {
256  auto dimOp = cast<tensor::DimOp>(op);
257  FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
258  if (failed(v))
259  return failure();
260  replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
261  dimOp.getIndex());
262  return success();
263  }
264 };
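
// Example (sketch): `%d = tensor.dim %t, %c0 : tensor<?xf32>` becomes
// `%d = memref.dim %m, %c0` on the source buffer %m; only metadata is read.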
265 
266 /// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
267 struct EmptyOpInterface
268  : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
269  tensor::EmptyOp> {
270  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
271 
272  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
273  const AnalysisState &state) const {
274  // The returned tensor does not have specified contents.
275  return false;
276  }
277 
278  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
279  const BufferizationOptions &options,
280  BufferizationState &state) const {
281  auto emptyOp = cast<tensor::EmptyOp>(op);
282 
283  // Optimization: Fold away the op if it has no uses.
284  if (op->getUses().empty()) {
285  rewriter.eraseOp(op);
286  return success();
287  }
288 
289  // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
290  FailureOr<Value> allocTensor = allocateTensorForShapedValue(
291  rewriter, op->getLoc(), emptyOp.getResult(), options, state,
292  /*copy=*/false);
293  if (failed(allocTensor))
294  return failure();
295  rewriter.replaceOp(op, *allocTensor);
296  return success();
297  }
298 };
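
// Example (sketch): `%0 = tensor.empty(%sz) : tensor<?xf32>` is rewritten to
// `%0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>`, which later lowers
// to an actual allocation; a tensor.empty without uses is simply erased.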
299 
300 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
301 struct ExpandShapeOpInterface
302  : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
303  tensor::ExpandShapeOp> {
304  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
305  const AnalysisState &state) const {
306  // In contrast to tensor.collapse_shape, this op can always be bufferized
307  // without a copy.
308  return false;
309  }
310 
311  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
312  const AnalysisState &state) const {
313  return false;
314  }
315 
316  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
317  const AnalysisState &state) const {
318  return {{op->getOpResult(0), BufferRelation::Equivalent}};
319  }
320 
321  FailureOr<BaseMemRefType>
322  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
323  const BufferizationState &state,
324  SmallVector<Value> &invocationStack) const {
325  auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
326  auto maybeSrcBufferType = bufferization::getBufferType(
327  expandShapeOp.getSrc(), options, state, invocationStack);
328  if (failed(maybeSrcBufferType))
329  return failure();
330  auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
331  auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
332  srcBufferType, expandShapeOp.getResultType().getShape(),
333  expandShapeOp.getReassociationIndices());
334  if (failed(maybeResultType))
335  return failure();
336  return *maybeResultType;
337  }
338 
339  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
340  const BufferizationOptions &options,
341  BufferizationState &state) const {
342  auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
343  auto tensorResultType = expandShapeOp.getResultType();
344  FailureOr<Value> buffer =
345  getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
346  if (failed(buffer))
347  return failure();
348 
349  auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
350  op->getLoc(), tensorResultType.getShape(), *buffer,
351  expandShapeOp.getReassociationIndices(),
352  expandShapeOp.getMixedOutputShape());
353  replaceOpWithBufferizedValues(rewriter, op,
354  memrefExpandShape->getResults());
355  return success();
356  }
357 };
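
// Example (sketch):
//
//   %e = tensor.expand_shape %t [[0, 1]] output_shape [2, 4]
//       : tensor<8xf32> into tensor<2x4xf32>
//
// bufferizes to the corresponding memref.expand_shape on the source buffer:
//
//   %e = memref.expand_shape %m [[0, 1]] output_shape [2, 4]
//       : memref<8xf32> into memref<2x4xf32>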
358 
359 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
360 struct ExtractSliceOpInterface
361  : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
362  tensor::ExtractSliceOp> {
363  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
364  const AnalysisState &state) const {
365  return false;
366  }
367 
368  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
369  const AnalysisState &state) const {
370  return false;
371  }
372 
373  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
374  const AnalysisState &state) const {
375  return {{op->getOpResult(0), BufferRelation::Unknown}};
376  }
377 
378  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
379  const BufferizationOptions &options,
380  BufferizationState &state) const {
381  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
382  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
383  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
384  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
385  Location loc = extractSliceOp.getLoc();
386 
387  // Get source buffer.
388  FailureOr<Value> srcMemref =
389  getBuffer(rewriter, extractSliceOp.getSource(), options, state);
390  if (failed(srcMemref))
391  return failure();
392 
393  // Take a subview of the source buffer.
394  auto resultMemrefType = bufferization::getBufferType(
395  extractSliceOp.getResult(), options, state);
396  if (failed(resultMemrefType))
397  return failure();
398  Value subView = rewriter.create<memref::SubViewOp>(
399  loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
400  mixedOffsets, mixedSizes, mixedStrides);
401 
402  replaceOpWithBufferizedValues(rewriter, op, subView);
403  return success();
404  }
405 
406  FailureOr<BaseMemRefType>
407  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
408  const BufferizationState &state,
409  SmallVector<Value> &invocationStack) const {
410  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
411  assert(value == extractSliceOp.getResult() && "invalid value");
412  auto srcMemrefType = bufferization::getBufferType(
413  extractSliceOp.getSource(), options, state, invocationStack);
414  if (failed(srcMemrefType))
415  return failure();
416  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
417  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
418  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
419  return memref::SubViewOp::inferRankReducedResultType(
420  extractSliceOp.getType().getShape(),
421  llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
422  mixedStrides);
423  }
424 };
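
// Example (sketch; the result type is the inferred rank-reduced subview type):
//
//   %s = tensor.extract_slice %t[4] [8] [1] : tensor<16xf32> to tensor<8xf32>
//
// bufferizes to
//
//   %s = memref.subview %m[4] [8] [1]
//       : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>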
425 
426 /// Bufferization of tensor.extract. Replace with memref.load.
427 struct ExtractOpInterface
428  : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
429  tensor::ExtractOp> {
430  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
431  const AnalysisState &state) const {
432  return true;
433  }
434 
435  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
436  const AnalysisState &state) const {
437  return false;
438  }
439 
440  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
441  const AnalysisState &state) const {
442  return {};
443  }
444 
445  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
446  const BufferizationOptions &options,
447  BufferizationState &state) const {
448  auto extractOp = cast<tensor::ExtractOp>(op);
449  FailureOr<Value> srcMemref =
450  getBuffer(rewriter, extractOp.getTensor(), options, state);
451  if (failed(srcMemref))
452  return failure();
453  replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
454  extractOp.getIndices());
455  return success();
456  }
457 };
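
// Example (sketch): `%e = tensor.extract %t[%i] : tensor<8xf32>` becomes
// `%e = memref.load %m[%i] : memref<8xf32>` on the source buffer %m.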
458 
459 // Implements backtracking to traverse indices of the output buffer while
460 // iterating over op.elements().
461 static void createStores(RewriterBase &rewriter, Location loc, int dim,
462  Value buffer, ArrayRef<int64_t> shape,
463  ArrayRef<Value> constants,
464  OperandRange::iterator &elementIt,
465  SmallVectorImpl<Value> &indices) {
466  if (dim == static_cast<int>(shape.size()) - 1) {
467  for (int i = 0; i < shape.back(); ++i) {
468  indices.back() = constants[i];
469  rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
470  ++elementIt;
471  }
472  return;
473  }
474  for (int i = 0; i < shape[dim]; ++i) {
475  indices[dim] = constants[i];
476  createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
477  indices);
478  }
479 }
480 
481 /// Bufferization of tensor.from_elements.
482 struct FromElementsOpInterface
483  : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
484  tensor::FromElementsOp> {
485 
486  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
487 
488  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
489  const BufferizationOptions &options,
490  BufferizationState &state) const {
491  auto fromElementsOp = cast<tensor::FromElementsOp>(op);
492  auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
493 
494  // Allocate a buffer for the result.
495  Location loc = op->getLoc();
496  auto shape = tensorType.getShape();
497  // TODO: Create alloc_tensor ops during TensorCopyInsertion.
498  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
499  rewriter, loc, fromElementsOp.getResult(), options, state,
500  /*copy=*/false);
501  if (failed(tensorAlloc))
502  return failure();
503  FailureOr<BaseMemRefType> memrefType =
504  bufferization::getBufferType(*tensorAlloc, options, state);
505  if (failed(memrefType))
506  return failure();
507  Value buffer = rewriter.create<bufferization::ToBufferOp>(
508  op->getLoc(), *memrefType, *tensorAlloc);
509 
510  // Case: tensor<0xelem_type>.
511  if (fromElementsOp.getElements().empty()) {
512  replaceOpWithBufferizedValues(rewriter, op, buffer);
513  return success();
514  }
515 
516  // Case: tensor<elem_type>.
517  if (shape.empty()) {
518  rewriter.create<memref::StoreOp>(
519  loc, fromElementsOp.getElements().front(), buffer);
520  replaceOpWithBufferizedValues(rewriter, op, buffer);
521  return success();
522  }
523 
524  // Create constants for the range of possible indices [0, max{shape_i}).
525  auto maxDim = *llvm::max_element(shape);
526  SmallVector<Value, 2> constants;
527  constants.reserve(maxDim);
528  for (int i = 0; i < maxDim; ++i)
529  constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
530 
531  // Traverse all `elements` and create `memref.store` ops.
532  auto elementIt = fromElementsOp.getElements().begin();
533  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
534  createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
535  indices);
536 
537  replaceOpWithBufferizedValues(rewriter, op, buffer);
538 
539  return success();
540  }
541 };
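
// Example (sketch): the elements are written one by one into a new buffer, so
//
//   %t = tensor.from_elements %a, %b : tensor<2xf32>
//
// becomes, roughly,
//
//   %alloc = memref.alloc() : memref<2xf32>
//   memref.store %a, %alloc[%c0] : memref<2xf32>
//   memref.store %b, %alloc[%c1] : memref<2xf32>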
542 
543 /// Lower the body of a tensor.generate-like op (one index-typed bbArg per dim).
544 /// Such ops are lowered to linalg.map with the given tensor as a destination.
545 ///
546 /// Example:
547 /// ```
548 /// %r = tensor.generate %x, %y {
549 /// ^bb0(%arg0: index, %arg1: index):
550 /// %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
551 /// tensor.yield %0 : index
552 /// } : tensor<?x?xindex>
553 /// ```
554 ///
555 /// Is lowered to:
556 /// ```
557 /// linalg.map ins() outs(%dest) {
558 /// %d0 = linalg.index 0 : index
559 /// %d1 = linalg.index 1 : index
560 /// %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
561 /// linalg.yield %0 : index
562 /// }
563 /// ```
564 static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
565  Value tensorDestination,
566  ValueRange dynamicSizes,
567  Region &generateBody) {
568  assert(generateBody.hasOneBlock() && "expected body with single block");
569  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
570  assert(generateBody.getNumArguments() == tensorType.getRank() &&
571  "rank mismatch");
572 
573  // Create linalg::MapOp.
574  OpBuilder::InsertionGuard g(rewriter);
575  auto linalgOp =
576  rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
577  /*init=*/tensorDestination);
578  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
579 
580  // Create linalg::IndexOps.
581  rewriter.setInsertionPointToStart(&linalgBody);
582  SmallVector<Value> indices;
583  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
584  indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));
585 
586  // Move over body.
587  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
588  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
589  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
590 
591  return linalgOp.getResult()[0];
592 }
593 
594 /// Bufferization of tensor.generate.
595 struct GenerateOpInterface
596  : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
597  tensor::GenerateOp> {
598 
599  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
600 
601  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
602  const BufferizationOptions &options,
603  BufferizationState &state) const {
604  auto generateOp = cast<tensor::GenerateOp>(op);
605 
606  auto type = generateOp.getResult().getType();
607 
608  // TODO: Implement memory space for this op.
609  if (options.defaultMemorySpaceFn(type) != Attribute())
610  return op->emitError("memory space not implemented yet");
611 
612  // Allocate memory.
613  Location loc = op->getLoc();
614  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
615  rewriter, loc, generateOp.getResult(), options, state,
616  /*copy=*/false);
617  if (failed(tensorAlloc))
618  return failure();
619 
620  Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
621  generateOp.getDynamicExtents(),
622  generateOp.getBody());
623  rewriter.replaceOp(generateOp, result);
624 
625  return success();
626  }
627 };
628 
629 /// Bufferization of tensor.insert. Replace with memref.store.
630 ///
631 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
632 /// implementations for DestinationStyle ops.
633 struct InsertOpInterface
634  : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
635  tensor::InsertOp> {
636  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
637  const BufferizationOptions &options,
638  BufferizationState &state) const {
639  auto insertOp = cast<tensor::InsertOp>(op);
640  FailureOr<Value> destMemref =
641  getBuffer(rewriter, insertOp.getDest(), options, state);
642  if (failed(destMemref))
643  return failure();
644  rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
645  *destMemref, insertOp.getIndices());
646  replaceOpWithBufferizedValues(rewriter, op, *destMemref);
647  return success();
648  }
649 };
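
// Example (sketch): `%r = tensor.insert %f into %t[%i] : tensor<8xf32>`
// becomes `memref.store %f, %m[%i] : memref<8xf32>`, and %r is replaced by the
// destination buffer %m.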
650 
651 template <typename InsertOpTy>
652 static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
653  OpOperand &opOperand) {
654  // The source is always read.
655  if (opOperand == insertSliceOp.getSourceMutable())
656  return true;
657 
658  // For the destination, it depends...
659  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
660 
661  // Dest is not read if it is entirely overwritten. E.g.:
662  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
663  bool allOffsetsZero =
664  llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
665  RankedTensorType destType = insertSliceOp.getDestType();
666  bool sizesMatchDestSizes =
667  areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
668  bool allStridesOne =
669  areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
670  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
671 }
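
// Example (sketch): in
//   tensor.insert_slice %src into %dst[0] [10] [1]
//       : tensor<10xf32> into tensor<10xf32>
// the destination is entirely overwritten and therefore not read; any non-zero
// offset, smaller size, or non-unit stride leaves destination elements visible
// in the result, so the destination counts as read.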
672 
673 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
674 /// certain circumstances, this op can also be a no-op.
675 ///
676 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
677 /// implementations for DestinationStyle ops.
678 struct InsertSliceOpInterface
679  : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
680  tensor::InsertSliceOp> {
681  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
682  const AnalysisState &state) const {
683  return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
684  opOperand);
685  }
686 
687  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
688  const BufferizationOptions &options,
689  BufferizationState &state) const {
690  // insert_slice ops arise from tiling and bufferizing them out-of-place is
691  // generally a deal breaker. When used with loops, this ends up cloning the
692  // whole tensor on every single iteration and is a symptom of a
693  // catastrophically bad scheduling decision.
694  // TODO: be very loud about it or even consider failing the pass.
695  auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
696  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
697  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
698  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
699  Location loc = insertSliceOp.getLoc();
700 
701  // Get destination buffer.
702  FailureOr<Value> dstMemref =
703  getBuffer(rewriter, insertSliceOp.getDest(), options, state);
704  if (failed(dstMemref))
705  return failure();
706 
707  // Take a subview of the destination buffer.
708  auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
709  MemRefType subviewMemRefType =
710  memref::SubViewOp::inferRankReducedResultType(
711  insertSliceOp.getSourceType().getShape(), dstMemrefType,
712  mixedOffsets, mixedSizes, mixedStrides);
713  Value subView = rewriter.create<memref::SubViewOp>(
714  loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
715  mixedStrides);
716 
717  // Copy tensor. If this tensor.insert_slice has a matching
718  // tensor.extract_slice, the copy operation will eventually fold away.
719  FailureOr<Value> srcMemref =
720  getBuffer(rewriter, insertSliceOp.getSource(), options, state);
721  if (failed(srcMemref))
722  return failure();
723  if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
724  return failure();
725 
726  replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
727  return success();
728  }
729 };
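
// Example (sketch; the copy is emitted via options.createMemCpy, which is
// memref.copy with default options):
//
//   %r = tensor.insert_slice %src into %dst[4] [8] [1]
//       : tensor<8xf32> into tensor<16xf32>
//
// bufferizes to
//
//   %sv = memref.subview %dst_m[4] [8] [1]
//       : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
//   memref.copy %src_m, %sv
//
// with %r replaced by %dst_m.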
730 
731 /// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
732 /// linalg.map + insert_slice.
733 /// For best performance, vectorize before bufferization (better performance in
734 /// case of padding with a constant).
735 struct PadOpInterface
736  : public BufferizableOpInterface::ExternalModel<PadOpInterface,
737  tensor::PadOp> {
738  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
739 
740  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
741  const AnalysisState &state) const {
742  return true;
743  }
744 
745  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
746  const AnalysisState &state) const {
747  return false;
748  }
749 
750  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
751  const AnalysisState &state) const {
752  return {};
753  }
754 
755  FailureOr<BaseMemRefType>
756  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
757  const BufferizationState &state,
758  SmallVector<Value> &invocationStack) const {
759  // Infer memory space from the source tensor.
760  auto padOp = cast<tensor::PadOp>(op);
761  auto maybeSrcBufferType = bufferization::getBufferType(
762  padOp.getSource(), options, state, invocationStack);
763  if (failed(maybeSrcBufferType))
764  return failure();
765  MemRefLayoutAttrInterface layout;
766  return MemRefType::get(padOp.getResultType().getShape(),
767  padOp.getResultType().getElementType(), layout,
768  maybeSrcBufferType->getMemorySpace());
769  }
770 
771  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
772  const BufferizationOptions &options,
773  BufferizationState &state) const {
774  auto padOp = cast<tensor::PadOp>(op);
775  Location loc = padOp.getLoc();
776  RankedTensorType resultType = padOp.getResultType();
777  RankedTensorType srcType = padOp.getSourceType();
778 
779  auto toValue = [&](OpFoldResult ofr) {
780  if (auto value = dyn_cast<Value>(ofr))
781  return value;
782  return rewriter
783  .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
784  .getResult();
785  };
786 
787  // Compute dynamic result dimensions.
788  SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
789  SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
790  SmallVector<Value> dynamicSizes;
791  for (int64_t i = 0; i < resultType.getRank(); ++i) {
792  if (!resultType.isDynamicDim(i))
793  continue;
794  Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
795  Value lowPad = toValue(mixedLowPad[i]);
796  Value highPad = toValue(mixedHighPad[i]);
797  AffineExpr s0, s1, s2;
798  bindSymbols(op->getContext(), s0, s1, s2);
799  AffineExpr sumExpr = s0 + s1 + s2;
800  Value sum = rewriter.create<affine::AffineApplyOp>(
801  loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
802  dynamicSizes.push_back(sum);
803  }
804 
805  // Allocate a buffer for the padded result.
806  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
807  rewriter, loc, padOp.getResult(), options, state,
808  /*copy=*/false);
809  if (failed(tensorAlloc))
810  return failure();
811 
812  // tensor::PadOp is like tensor::GenerateOp: The only difference is that
813  // only a part of the generated tensor is needed. For simplicity, we reuse
814  // the same functionality here.
815  Value filledBuffer = lowerGenerateLikeOpBody(
816  rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
817 
818  // Create tensor::InsertSliceOp.
819  SmallVector<OpFoldResult> sliceSizes =
820  getMixedSizes(rewriter, loc, padOp.getSource());
821  SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
822  rewriter.getIndexAttr(1));
823  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
824  padOp, padOp.getSource(), filledBuffer,
825  /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
826 
827  return success();
828  }
829 };
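
// Example (sketch, following the tensor.generate-style lowering above):
//
//   %p = tensor.pad %t low[1] high[1] { ... tensor.yield %cst : f32 }
//       : tensor<8xf32> to tensor<10xf32>
//
// becomes, roughly,
//
//   %alloc = bufferization.alloc_tensor() : tensor<10xf32>
//   %fill  = linalg.map ins() outs(%alloc) { linalg.yield %cst }
//   %p     = tensor.insert_slice %t into %fill[1] [8] [1]
//            : tensor<8xf32> into tensor<10xf32>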
830 
831 /// Bufferization of tensor.rank. Replace with memref.rank.
832 struct RankOpInterface
833  : public BufferizableOpInterface::ExternalModel<RankOpInterface,
834  tensor::RankOp> {
835  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
836  const AnalysisState &state) const {
837  // The op reads the tensor's metadata but not its contents.
838  return false;
839  }
840 
841  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
842  const AnalysisState &state) const {
843  return false;
844  }
845 
846  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
847  const AnalysisState &state) const {
848  return {};
849  }
850 
851  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
852  const BufferizationOptions &options,
853  BufferizationState &state) const {
854  auto rankOp = cast<tensor::RankOp>(op);
855  FailureOr<Value> v =
856  getBuffer(rewriter, rankOp.getTensor(), options, state);
857  if (failed(v))
858  return failure();
859  replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
860  *v);
861  return success();
862  }
863 };
864 
865 /// Bufferization of tensor.reshape. Replace with memref.reshape.
866 struct ReshapeOpInterface
867  : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
868  tensor::ReshapeOp> {
869  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
870  const AnalysisState &state) const {
871  // Depending on the layout map, the source buffer may have to be copied.
872  auto reshapeOp = cast<tensor::ReshapeOp>(op);
873  return opOperand == reshapeOp.getShapeMutable();
874  }
875 
876  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
877  const AnalysisState &state) const {
878  return false;
879  }
880 
881  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
882  const AnalysisState &state) const {
883  // Only the 'source' operand aliases the result.
884  auto reshapeOp = cast<tensor::ReshapeOp>(op);
885  if (reshapeOp.getSourceMutable() != opOperand)
886  return {};
887  return {{op->getOpResult(0), BufferRelation::Equivalent}};
888  }
889 
890  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
891  const BufferizationOptions &options,
892  BufferizationState &state) const {
893  auto reshapeOp = cast<tensor::ReshapeOp>(op);
894  FailureOr<Value> srcBuffer =
895  getBuffer(rewriter, reshapeOp.getSource(), options, state);
896  FailureOr<Value> shapeBuffer =
897  getBuffer(rewriter, reshapeOp.getShape(), options, state);
898  if (failed(srcBuffer) || failed(shapeBuffer))
899  return failure();
900  auto maybeResultMemRefType =
901  bufferization::getBufferType(reshapeOp.getResult(), options, state);
902  if (failed(maybeResultMemRefType))
903  return failure();
904 
905  // memref.reshape requires the source buffer to have an identity layout.
906  // If the source memref does not have an identity layout, copy the source
907  // into a new buffer with an identity layout.
908  auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
909  if (srcType && !srcType.getLayout().isIdentity()) {
910  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
911  rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
912  if (failed(tensorAlloc))
913  return failure();
914  auto memrefType = MemRefType::get(
915  srcType.getShape(), srcType.getElementType(), AffineMap(),
916  cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
917  srcBuffer = rewriter
918  .create<bufferization::ToBufferOp>(
919  op->getLoc(), memrefType, *tensorAlloc)
920  .getResult();
921  }
922 
923  replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
924  rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
925  return success();
926  }
927 
928  FailureOr<BaseMemRefType>
929  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
930  const BufferizationState &state,
931  SmallVector<Value> &invocationStack) const {
932  auto reshapeOp = cast<tensor::ReshapeOp>(op);
933  assert(value == reshapeOp.getResult() && "unexpected value provided");
934  auto maybeSourceBufferType = bufferization::getBufferType(
935  reshapeOp.getSource(), options, state, invocationStack);
936  if (failed(maybeSourceBufferType))
937  return failure();
938  return getMemRefTypeWithStaticIdentityLayout(
939  reshapeOp.getResult().getType(),
940  cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
941  }
942 };
943 
944 /// Analysis of ParallelInsertSliceOp.
945 struct ParallelInsertSliceOpInterface
946  : public BufferizableOpInterface::ExternalModel<
947  ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
948  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
949  const AnalysisState &state) const {
950  return {};
951  }
952 
953  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
954  const AnalysisState &state) const {
955  return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
956  }
957 
958  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
959  const AnalysisState &state) const {
960  auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
961  return opOperand == parallelInsertSliceOp.getDestMutable();
962  }
963 
964  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
965  const BufferizationOptions &options,
966  BufferizationState &state) const {
967  OpBuilder::InsertionGuard g(rewriter);
968  auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
969  ParallelCombiningOpInterface parallelCombiningParent =
970  parallelInsertSliceOp.getParallelCombiningParent();
971 
972  // Bufferize the op outside of the parallel combining terminator.
973  rewriter.setInsertionPoint(parallelCombiningParent);
974 
975  // Get source and destination buffers.
976  FailureOr<Value> destBuffer =
977  getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
978  if (failed(destBuffer))
979  return failure();
980  FailureOr<Value> srcBuffer =
981  getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
982  if (failed(srcBuffer))
983  return failure();
984 
985  // Take a subview of the destination buffer.
986  auto destBufferType = cast<MemRefType>(destBuffer->getType());
987  MemRefType subviewMemRefType =
988  memref::SubViewOp::inferRankReducedResultType(
989  parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
990  parallelInsertSliceOp.getMixedOffsets(),
991  parallelInsertSliceOp.getMixedSizes(),
992  parallelInsertSliceOp.getMixedStrides());
993  Value subview = rewriter.create<memref::SubViewOp>(
994  parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
995  parallelInsertSliceOp.getMixedOffsets(),
996  parallelInsertSliceOp.getMixedSizes(),
997  parallelInsertSliceOp.getMixedStrides());
998 
999  // This memcpy will fold away if everything bufferizes in-place.
1000  if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
1001  *srcBuffer, subview)))
1002  return failure();
1003 
1004  // In case the source was allocated in the same block, make sure that the
1005  // deallocation op (if any) appears after the memcpy. By default, deallocs
1006  // are placed before the terminator, but this does not work for ForallOp
1007  // because the terminator does more than just yielding a value.
1008  //
1009  // Note: This is not a problem for the destination buffer because these are
1010  // assumed to always bufferize in-place.
1011  for (Operation *user : srcBuffer->getUsers()) {
1012  if (hasEffect<MemoryEffects::Free>(user)) {
1013  if (user->getBlock() == parallelCombiningParent->getBlock())
1014  rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
1015  break;
1016  }
1017  }
1018 
1019  // Delete the op.
1020  rewriter.eraseOp(op);
1021  return success();
1022  }
1023 
1024  /// tensor.parallel_insert_slice op has implicit inplace behavior. We
1025  /// shouldn't create copy to resolve conflict.
1026  LogicalResult
1027  resolveConflicts(Operation *op, RewriterBase &rewriter,
1028  const AnalysisState &analysisState,
1029  const BufferizationState &bufferizationState) const {
1030  return success();
1031  }
1032 };
1033 
1034 /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
1035 /// with a linalg.map. Similar to tensor.generate.
1036 struct SplatOpInterface
1037  : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1038  tensor::SplatOp> {
1039 
1040  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1041 
1042  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1043  const BufferizationOptions &options,
1044  BufferizationState &state) const {
1045  OpBuilder::InsertionGuard g(rewriter);
1046  auto splatOp = cast<tensor::SplatOp>(op);
1047 
1048  // Allocate memory.
1049  Location loc = op->getLoc();
1050  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1051  rewriter, loc, splatOp.getResult(), options, state,
1052  /*copy=*/false);
1053  if (failed(tensorAlloc))
1054  return failure();
1055 
1056  // Create linalg::MapOp.
1057  auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1058 
1059  // TODO: Implement memory space for this op.
1060  if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1061  return op->emitError("memory space not implemented yet");
1062 
1063  auto linalgOp =
1064  rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
1065  /*init=*/*tensorAlloc);
1066  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1067 
1068  // Create linalg::IndexOps.
1069  rewriter.setInsertionPointToStart(&linalgBody);
1070  rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
1071  rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
1072 
1073  return success();
1074  }
1075 };
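
// Example (sketch): `%s = tensor.splat %v : tensor<4xf32>` allocates a tensor
// and fills it with a linalg.map whose body simply yields %v, analogous to the
// tensor.generate lowering but without linalg.index ops.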
1076 
1077 /// Bufferization of tensor.concat. Bufferizes to a new allocation that is
1078 /// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
1079 /// on subviews instead of memref.store.
1080 struct ConcatOpInterface
1081  : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
1082  tensor::ConcatOp> {
1083 
1084  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1085 
1086  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1087  const AnalysisState &state) const {
1088  return false;
1089  }
1090 
1091  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1092  const AnalysisState &state) const {
1093  return true;
1094  }
1095 
1096  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1097  const AnalysisState &state) const {
1098  return {};
1099  }
1100 
1101  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1102  const BufferizationOptions &options,
1103  BufferizationState &state) const {
1104  OpBuilder::InsertionGuard g(rewriter);
1105  auto concatOp = cast<tensor::ConcatOp>(op);
1106 
1107  // Allocate memory.
1108  Location loc = op->getLoc();
1109  FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1110  rewriter, loc, concatOp.getResult(), options, state,
1111  /*copy=*/false);
1112  if (failed(tensorAlloc))
1113  return failure();
1114  auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1115 
1116  // TODO: Implement memory space for this op.
1117  if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1118  return op->emitError("memory space not implemented yet");
1119 
1120  MemRefLayoutAttrInterface layout;
1121  MemRefType memrefType =
1122  MemRefType::get(concatOp.getResultType().getShape(),
1123  concatOp.getResultType().getElementType(), layout);
1124  Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
1125  op->getLoc(), memrefType, *tensorAlloc);
1126 
1127  // Extract the dimension along which the operands are concatenated.
1128  uint64_t concatDim = concatOp.getDim();
1129  bool dynamicConcatDim = false;
1130 
1131  SmallVector<OpFoldResult> offsets(tensorType.getRank(),
1132  rewriter.getIndexAttr(0));
1133  SmallVector<OpFoldResult> strides(tensorType.getRank(),
1134  rewriter.getIndexAttr(1));
1135  SmallVector<OpFoldResult> sizes;
1136 
1137  for (const auto &[dimIdx, dimSize] :
1138  llvm::enumerate(tensorType.getShape())) {
1139  if (dimSize == ShapedType::kDynamic) {
1140  auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
1141  sizes.push_back(dimOp.getResult());
1142  if (dimIdx == concatDim)
1143  dynamicConcatDim = true;
1144  } else {
1145  sizes.push_back(rewriter.getIndexAttr(dimSize));
1146  }
1147  }
1148 
1149  int64_t concatDimOffset = 0;
1150  std::optional<Value> dynamicOffset;
1151  std::optional<Value> dynamicSize;
1152  if (dynamicConcatDim) {
1153  // One or more operands have dynamic size, so we must accumulate the
1154  // offset with arith ops.
1155  dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1156  }
1157 
1158  for (auto operand : concatOp.getInputs()) {
1159  // Get the buffer for the operand.
1160  FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
1161  if (failed(srcBuffer))
1162  return failure();
1163 
1164  // Each operand may have a different size along the concat dimension,
1165  // so the offset on that axis must accumulate through the loop, and the
1166  // size must change to the size of the current operand.
1167  auto operandTensorType = cast<RankedTensorType>(operand.getType());
1168  int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
1169 
1170  if (dynamicConcatDim) {
1171  offsets[concatDim] = dynamicOffset.value();
1172  dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
1173  .getResult();
1174  sizes[concatDim] = dynamicSize.value();
1175  } else {
1176  sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
1177  offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
1178  }
1179 
1180  // Create a subview of the destination buffer.
1181  auto dstMemrefType = cast<MemRefType>(memrefType);
1182  MemRefType subviewMemRefType =
1183  memref::SubViewOp::inferRankReducedResultType(
1184  operandTensorType.getShape(), dstMemrefType, offsets, sizes,
1185  strides);
1186  Value subview = rewriter.create<memref::SubViewOp>(
1187  loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
1188 
1189  // Copy the source buffer into the destination subview.
1190  if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
1191  return failure();
1192 
1193  if (dynamicConcatDim) {
1194  dynamicOffset = rewriter.create<arith::AddIOp>(
1195  loc, dynamicOffset.value(), dynamicSize.value());
1196  } else {
1197  concatDimOffset += operandConcatDimSize;
1198  }
1199  }
1200 
1201  replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
1202  return success();
1203  }
1204 };
1205 
1206 } // namespace
1207 } // namespace tensor
1208 } // namespace mlir
1209 
1210 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
1211  DialectRegistry &registry) {
1212  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
1213  CastOp::attachInterface<CastOpInterface>(*ctx);
1214  CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1215  ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
1216  DimOp::attachInterface<DimOpInterface>(*ctx);
1217  EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1218  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1219  ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1220  ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1221  FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1222  GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1223  InsertOp::attachInterface<InsertOpInterface>(*ctx);
1224  InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1225  PadOp::attachInterface<PadOpInterface>(*ctx);
1226  ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1227  *ctx);
1228  RankOp::attachInterface<RankOpInterface>(*ctx);
1229  ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1230  SplatOp::attachInterface<SplatOpInterface>(*ctx);
1231 
1232  // Load additional dialects of which ops may get created.
1233  ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
1234  });
1235 
1236  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
1237  // they are registered.
1238  registerSubsetOpInterfaceExternalModels(registry);
1239 }
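
// Typical usage (sketch): clients register these models on a DialectRegistry
// before running One-Shot Bufferize, e.g.:
//
//   DialectRegistry registry;
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);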