1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
16 #include "mlir/IR/Matchers.h"
17 #include <optional>
19 using namespace mlir;
20 using namespace mlir::bufferization;
22 //===----------------------------------------------------------------------===//
23 // Helper functions
24 //===----------------------------------------------------------------------===//
27  OpBuilder &b, Value value, MemRefType destType,
29  auto srcType = llvm::cast<MemRefType>(value.getType());
31  // Element type, rank and memory space must match.
32  if (srcType.getElementType() != destType.getElementType())
33  return failure();
34  if (srcType.getMemorySpace() != destType.getMemorySpace())
35  return failure();
36  if (srcType.getRank() != destType.getRank())
37  return failure();
39  // In case the affine maps are different, we may need to use a copy if we go
40  // from dynamic to static offset or stride (the canonicalization cannot know
41  // at this point that it is really cast compatible).
42  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
43  int64_t sourceOffset, targetOffset;
44  SmallVector<int64_t, 4> sourceStrides, targetStrides;
45  if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
46  failed(getStridesAndOffset(target, targetStrides, targetOffset)))
47  return false;
48  auto dynamicToStatic = [](int64_t a, int64_t b) {
49  return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
50  };
51  if (dynamicToStatic(sourceOffset, targetOffset))
52  return false;
53  for (auto it : zip(sourceStrides, targetStrides))
54  if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
55  return false;
56  return true;
57  };
59  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
60  // ensure that we only generate casts that always succeed at runtime, we check
61  // a fix extra conditions in `isGuaranteedCastCompatible`.
62  if (memref::CastOp::areCastCompatible(srcType, destType) &&
63  isGuaranteedCastCompatible(srcType, destType)) {
64  Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
65  return casted;
66  }
68  auto loc = value.getLoc();
69  SmallVector<Value, 4> dynamicOperands;
70  for (int i = 0; i < destType.getRank(); ++i) {
71  if (destType.getShape()[i] != ShapedType::kDynamic)
72  continue;
73  Value size = b.create<memref::DimOp>(loc, value, i);
74  dynamicOperands.push_back(size);
75  }
77  FailureOr<Value> copy =
78  options.createAlloc(b, loc, destType, dynamicOperands);
79  if (failed(copy))
80  return failure();
81  if (failed(options.createMemCpy(b, loc, value, *copy)))
82  return failure();
83  return copy;
84 }
86 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
87 /// to_memref op are different, a memref.cast is needed.
89  RewriterBase &rewriter, ToMemrefOp toMemref,
91  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
92  if (!memrefToTensor)
93  return failure();
95  Type srcType = memrefToTensor.getMemref().getType();
96  Type destType = toMemref.getType();
98  // Directly rewrite if the type did not change.
99  if (srcType == destType) {
100  rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
101  return success();
102  }
104  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
105  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
106  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
108  // Ranked memref -> Ranked memref cast.
109  if (rankedSrcType && rankedDestType) {
110  FailureOr<Value> replacement = castOrReallocMemRefValue(
111  rewriter, memrefToTensor.getMemref(), rankedDestType, options);
112  if (failed(replacement))
113  return failure();
115  rewriter.replaceOp(toMemref, *replacement);
116  return success();
117  }
119  // Unranked memref -> Ranked memref cast: May require a copy.
120  // TODO: Not implemented at the moment.
121  if (unrankedSrcType && rankedDestType)
122  return failure();
124  // Unranked memref -> unranked memref cast
125  // Ranked memref -> unranked memref cast: No copy needed.
126  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
127  "expected that types are cast compatible");
128  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
129  memrefToTensor.getMemref());
130  return success();
131 }
134  OpBuilder &b, Location loc, Value shapedValue,
135  SmallVector<Value> &dynamicDims) {
136  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
137  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
138  if (shapedType.isDynamicDim(i)) {
139  if (llvm::isa<MemRefType>(shapedType)) {
140  dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
141  } else {
142  assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
143  dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
144  }
145  }
146  }
147 }
149 //===----------------------------------------------------------------------===//
150 // AllocTensorOp
151 //===----------------------------------------------------------------------===//
153 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
154  const BufferizationOptions &options) {
155  OpBuilder::InsertionGuard g(rewriter);
156  Location loc = getLoc();
158  // Nothing to do for dead AllocTensorOps.
159  if (getOperation()->getUses().empty()) {
160  rewriter.eraseOp(getOperation());
161  return success();
162  }
164  // Get "copy" buffer.
165  Value copyBuffer;
166  if (getCopy()) {
167  FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
168  if (failed(maybeCopyBuffer))
169  return failure();
170  copyBuffer = *maybeCopyBuffer;
171  }
173  // Create memory allocation.
174  auto allocType = bufferization::getBufferType(getResult(), options);
175  if (failed(allocType))
176  return failure();
177  SmallVector<Value> dynamicDims = getDynamicSizes();
178  if (getCopy()) {
179  assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
180  populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
181  }
182  FailureOr<Value> alloc = options.createAlloc(
183  rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
184  if (failed(alloc))
185  return failure();
187  // Create memory copy (if any).
188  if (getCopy()) {
189  if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
190  return failure();
191  }
193  // Replace op.
194  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
196  return success();
197 }
199 bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
200  const AnalysisState &state) {
201  // AllocTensorOps do not write unless they have a `copy` value.
202  return static_cast<bool>(getCopy());
203 }
205 bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
206  const AnalysisState &state) {
207  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
208  "expected copy operand");
209  return true;
210 }
212 bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
213  const AnalysisState &state) {
214  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
215  "expected copy operand");
216  return false;
217 }
219 AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
220  const AnalysisState &state) {
221  // This is a new allocation. It does not alias with any other buffer.
222  return {};
223 }
225 FailureOr<BaseMemRefType>
227  SmallVector<Value> &invocationStack) {
228  assert(value == getResult() && "invalid value");
230  // Compute memory space of this allocation.
231  Attribute memorySpace;
232  if (getMemorySpace().has_value()) {
233  memorySpace = *getMemorySpace();
234  } else if (getCopy()) {
235  auto copyBufferType =
236  bufferization::getBufferType(getCopy(), options, invocationStack);
237  if (failed(copyBufferType))
238  return failure();
239  memorySpace = copyBufferType->getMemorySpace();
240  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
241  memorySpace = *ms;
242  } else {
243  return getOperation()->emitError("could not infer memory space");
244  }
246  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
247 }
249 LogicalResult AllocTensorOp::verify() {
250  if (getCopy() && !getDynamicSizes().empty())
251  return emitError("dynamic sizes not needed when copying a tensor");
252  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
253  return emitError("expected ")
254  << getType().getNumDynamicDims() << " dynamic sizes";
255  if (getCopy() && getCopy().getType() != getType())
256  return emitError("expected that `copy` and return type match");
257  return success();
258 }
260 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
261  RankedTensorType type, ValueRange dynamicSizes) {
262  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
263  /*size_hint=*/Value(),
264  /*memory_space=*/IntegerAttr());
265 }
267 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
268  RankedTensorType type, ValueRange dynamicSizes,
269  Value copy) {
270  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
271  /*memory_space=*/IntegerAttr());
272 }
274 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
275  TensorType type, ValueRange dynamicSizes, Value copy,
276  IntegerAttr memorySpace) {
277  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
278  memorySpace);
279 }
281 namespace {
282 /// Change the type of the result of a `bufferization.alloc_tensor` by making
283 /// the result type statically sized along dimension that in the original
284 /// operation where defined as dynamic, but the size was defined using a
285 /// `constant` op. For example:
286 ///
287 /// %c5 = arith.constant 5: index
288 /// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
289 ///
290 /// to
291 ///
292 /// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
293 struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
296  LogicalResult matchAndRewrite(AllocTensorOp op,
297  PatternRewriter &rewriter) const override {
298  if (op.getCopy())
299  return failure();
300  SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
301  SmallVector<Value> newDynamicSizes;
302  unsigned int dynValCounter = 0;
303  for (int64_t i = 0; i < op.getType().getRank(); ++i) {
304  if (!op.isDynamicDim(i))
305  continue;
306  Value value = op.getDynamicSizes()[dynValCounter++];
307  APInt intVal;
308  if (matchPattern(value, m_ConstantInt(&intVal))) {
309  int64_t dim = intVal.getSExtValue();
310  if (dim >= 0)
311  newShape[i] = intVal.getSExtValue();
312  else
313  newDynamicSizes.push_back(value);
314  } else {
315  newDynamicSizes.push_back(value);
316  }
317  }
318  RankedTensorType newType = RankedTensorType::get(
319  newShape, op.getType().getElementType(), op.getType().getEncoding());
320  if (newType == op.getType())
321  return failure();
322  auto newOp = rewriter.create<AllocTensorOp>(
323  op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
324  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
325  return success();
326  }
327 };
329 struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
332  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
333  PatternRewriter &rewriter) const override {
334  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
335  auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
336  if (!allocTensorOp || !maybeConstantIndex)
337  return failure();
338  if (*maybeConstantIndex < 0 ||
339  *maybeConstantIndex >= allocTensorOp.getType().getRank())
340  return failure();
341  if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
342  return failure();
343  rewriter.replaceOp(
344  dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
345  return success();
346  }
347 };
348 } // namespace
350 void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
351  MLIRContext *ctx) {
352  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
353 }
356  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
357  auto shapes = llvm::to_vector<4>(
358  llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
359  [&](int64_t dim) -> OpFoldResult {
360  if (isDynamicDim(dim))
361  return getDynamicSize(builder, dim);
362  return builder.getIndexAttr(getStaticSize(dim));
363  }));
364  reifiedReturnShapes.emplace_back(std::move(shapes));
365  return success();
366 }
368 ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
369  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
370  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
371  parser.parseRParen())
372  return failure();
373  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
374  OpAsmParser::UnresolvedOperand copyOperand;
375  if (copyKeyword.succeeded())
376  if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
377  parser.parseRParen())
378  return failure();
379  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
380  OpAsmParser::UnresolvedOperand sizeHintOperand;
381  if (sizeHintKeyword.succeeded())
382  if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
383  return failure();
384  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
385  return failure();
387  TensorType type;
388  if (parser.parseCustomTypeWithFallback(type))
389  return failure();
390  result.addTypes(type);
392  Type indexType = parser.getBuilder().getIndexType();
393  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
394  return failure();
395  if (copyKeyword.succeeded())
396  if (parser.resolveOperand(copyOperand, type, result.operands))
397  return failure();
398  if (sizeHintKeyword.succeeded())
399  if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
400  return failure();
401  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
403  {static_cast<int32_t>(dynamicSizesOperands.size()),
404  static_cast<int32_t>(copyKeyword.succeeded()),
405  static_cast<int32_t>(sizeHintKeyword.succeeded())}));
406  return success();
407 }
410  p << "(" << getDynamicSizes() << ")";
411  if (getCopy())
412  p << " copy(" << getCopy() << ")";
413  if (getSizeHint())
414  p << " size_hint=" << getSizeHint();
415  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
416  AllocTensorOp::getOperandSegmentSizeAttr()});
417  p << " : ";
418  auto type = getResult().getType();
419  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
420  p.printStrippedAttrOrType(validType);
421  else
422  p << type;
423 }
425 Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
426  assert(isDynamicDim(idx) && "expected dynamic dim");
427  if (getCopy())
428  return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
429  return getOperand(getIndexOfDynamicSize(idx));
430 }
432 //===----------------------------------------------------------------------===//
433 // CloneOp
434 //===----------------------------------------------------------------------===//
436 OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
437  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
438 }
440 namespace {
442 /// Merge the clone and its source (by converting the clone to a cast) when
443 /// possible.
444 struct SimplifyClones : public OpRewritePattern<CloneOp> {
447  LogicalResult matchAndRewrite(CloneOp cloneOp,
448  PatternRewriter &rewriter) const override {
449  if (cloneOp.use_empty()) {
450  rewriter.eraseOp(cloneOp);
451  return success();
452  }
454  Value source = cloneOp.getInput();
455  if (source.getType() != cloneOp.getType() &&
456  !memref::CastOp::areCastCompatible({source.getType()},
457  {cloneOp.getType()}))
458  return failure();
460  // Aims to find the dealloc op for the canonical source
461  // which otherwise could prevent removal of unnecessary allocs.
462  Value canonicalSource = source;
463  while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
464  canonicalSource.getDefiningOp()))
465  canonicalSource = iface.getViewSource();
467  std::optional<Operation *> maybeCloneDeallocOp =
468  memref::findDealloc(cloneOp.getOutput());
469  // Skip if either of them has > 1 deallocate operations.
470  if (!maybeCloneDeallocOp.has_value())
471  return failure();
472  std::optional<Operation *> maybeSourceDeallocOp =
473  memref::findDealloc(canonicalSource);
474  if (!maybeSourceDeallocOp.has_value())
475  return failure();
476  Operation *cloneDeallocOp = *maybeCloneDeallocOp;
477  Operation *sourceDeallocOp = *maybeSourceDeallocOp;
479  // If both are deallocated in the same block, their in-block lifetimes
480  // might not fully overlap, so we cannot decide which one to drop.
481  if (cloneDeallocOp && sourceDeallocOp &&
482  cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
483  return failure();
485  Block *currentBlock = cloneOp->getBlock();
486  Operation *redundantDealloc = nullptr;
487  if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
488  redundantDealloc = cloneDeallocOp;
489  } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
490  redundantDealloc = sourceDeallocOp;
491  }
493  if (!redundantDealloc)
494  return failure();
496  // Safety check that there are no other deallocations inbetween
497  // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
498  // of source before the uses of the clone. With alias information, we could
499  // restrict this to only fail of the dealloc's operand is an alias
500  // of the source.
501  for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
502  pos = pos->getNextNode()) {
503  // Bail if we run out of operations while looking for a deallocation op.
504  if (!pos)
505  return failure();
506  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
507  if (!effectInterface)
508  continue;
509  if (effectInterface.hasEffect<MemoryEffects::Free>())
510  return failure();
511  }
513  if (source.getType() != cloneOp.getType())
514  source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
515  cloneOp.getType(), source);
516  rewriter.replaceOp(cloneOp, source);
517  rewriter.eraseOp(redundantDealloc);
518  return success();
519  }
520 };
522 } // namespace
524 void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
525  MLIRContext *context) {
526  results.add<SimplifyClones>(context);
527 }
529 //===----------------------------------------------------------------------===//
530 // DeallocTensorOp
531 //===----------------------------------------------------------------------===//
533 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
534  const BufferizationOptions &options) {
535  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
536  if (failed(buffer))
537  return failure();
538  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
539  rewriter.eraseOp(getOperation());
540  return success();
541 }
543 //===----------------------------------------------------------------------===//
544 // MaterializeInDestinationOp
545 //===----------------------------------------------------------------------===//
547 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
548  OpOperand &opOperand, const AnalysisState &state) {
549  return opOperand == getSourceMutable();
550 }
552 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
553  OpOperand &opOperand, const AnalysisState &state) {
554  if (opOperand == getDestMutable()) {
555  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
556  return true;
557  }
558  return false;
559 }
561 bool MaterializeInDestinationOp::mustBufferizeInPlace(
562  OpOperand &opOperand, const AnalysisState &state) {
563  // The source is only read and not written, so it always bufferizes in-place
564  // by default. The destination is written and is forced to bufferize in-place
565  // (if it is a tensor).
566  return true;
567 }
570 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
571  const AnalysisState &state) {
572  if (opOperand == getDestMutable()) {
573  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
574  return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
575  }
576  return {};
577 }
579 LogicalResult
580 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
581  const BufferizationOptions &options) {
582  bool tensorDest = isa<TensorType>(getDest().getType());
583  Value buffer;
584  if (tensorDest) {
585  FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
586  if (failed(maybeBuffer))
587  return failure();
588  buffer = *maybeBuffer;
589  } else {
590  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
591  buffer = getDest();
592  }
593  auto srcBuffer = getBuffer(rewriter, getSource(), options);
594  if (failed(srcBuffer))
595  return failure();
596  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
597  return failure();
598  replaceOpWithBufferizedValues(rewriter, getOperation(),
599  tensorDest ? ValueRange(buffer) : ValueRange());
600  return success();
601 }
603 bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
604  const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
605  // As elements are copied from the "source" buffer to the "dest" buffer,
606  // already copied elements are not read a second time.
607  return true;
608 }
611  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
612  if (getOperation()->getNumResults() == 1) {
613  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
614  reifiedReturnShapes.resize(1,
615  SmallVector<OpFoldResult>(getType().getRank()));
616  reifiedReturnShapes[0] =
617  tensor::getMixedSizes(builder, getLoc(), getDest());
618  }
619  return success();
620 }
622 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
623  Location loc) {
624  if (isa<TensorType>(getDest().getType())) {
625  // The subset is the entire destination tensor.
626  return getDest();
627  }
629  // The "restrict" attribute is transferred from this op to the newly created
630  // to_tensor op. If this op does not the "restrict" attribute, the subset
631  // extraction cannot be built because there is no guarantee that there is no
632  // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
633  if (!getRestrict())
634  return {};
636  // Build a bufferization.to_tensor op.
637  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
638  assert(getRestrict() &&
639  "expected that ops with memrefs dest have 'restrict'");
640  setRestrict(false);
641  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
642  getWritable());
643 }
645 bool MaterializeInDestinationOp::isEquivalentSubset(
646  Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
647  return equivalenceFn(getDest(), candidate);
648 }
651 MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
652  return {getDest()};
653 }
655 OpOperand &MaterializeInDestinationOp::getSourceOperand() {
656  return getOperation()->getOpOperand(0) /*source*/;
657 }
659 bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
660  SubsetOpInterface subsetOp,
661  function_ref<bool(Value, Value)> equivalenceFn) {
662  return false;
663 }
665 bool MaterializeInDestinationOp::operatesOnDisjointSubset(
666  SubsetOpInterface subsetOp,
667  function_ref<bool(Value, Value)> equivalenceFn) {
668  return false;
669 }
671 LogicalResult MaterializeInDestinationOp::verify() {
672  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
673  return emitOpError("'dest' must be a tensor or a memref");
674  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
675  if (getOperation()->getNumResults() != 1)
676  return emitOpError("tensor 'dest' implies exactly one tensor result");
677  if (destType != getResult().getType())
678  return emitOpError("result and 'dest' types must match");
679  }
680  if (isa<BaseMemRefType>(getDest().getType()) &&
681  getOperation()->getNumResults() != 0)
682  return emitOpError("memref 'dest' implies zero results");
683  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
684  return emitOpError("'restrict' is valid only for memref destinations");
685  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
686  return emitOpError("'writable' must be specified if and only if the "
687  "destination is of memref type");
688  TensorType srcType = getSource().getType();
689  ShapedType destType = cast<ShapedType>(getDest().getType());
690  if (srcType.hasRank() != destType.hasRank())
691  return emitOpError("source/destination shapes are incompatible");
692  if (srcType.hasRank()) {
693  if (srcType.getRank() != destType.getRank())
694  return emitOpError("rank mismatch between source and destination shape");
695  for (auto [src, dest] :
696  llvm::zip(srcType.getShape(), destType.getShape())) {
697  if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
698  // Cannot verify dynamic dimension size. Assume that that they match at
699  // runtime.
700  continue;
701  }
702  if (src != dest)
703  return emitOpError("source/destination shapes are incompatible");
704  }
705  }
706  return success();
707 }
709 void MaterializeInDestinationOp::build(OpBuilder &builder,
710  OperationState &state, Value source,
711  Value dest) {
712  auto destTensorType = dyn_cast<TensorType>(dest.getType());
713  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
714  source, dest);
715 }
717 bool MaterializeInDestinationOp::isWritable(Value value,
718  const AnalysisState &state) {
719  return isa<TensorType>(getDest().getType()) ? true : getWritable();
720 }
722 MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
723  return getDestMutable();
724 }
726 void MaterializeInDestinationOp::getEffects(
728  &effects) {
729  if (isa<BaseMemRefType>(getDest().getType()))
730  effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
732 }
734 //===----------------------------------------------------------------------===//
735 // ToTensorOp
736 //===----------------------------------------------------------------------===//
738 bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
739  return getWritable();
740 }
742 OpFoldResult ToTensorOp::fold(FoldAdaptor) {
743  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
744  // Approximate alias analysis by conservatively folding only when no there
745  // is no interleaved operation.
746  if (toMemref->getBlock() == this->getOperation()->getBlock() &&
747  toMemref->getNextNode() == this->getOperation())
748  return toMemref.getTensor();
749  return {};
750 }
752 namespace {
753 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
756  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
757  PatternRewriter &rewriter) const override {
758  auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
759  if (!memrefToTensorOp)
760  return failure();
762  rewriter.replaceOpWithNewOp<memref::DimOp>(
763  dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
764  return success();
765  }
766 };
767 } // namespace
769 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
770  MLIRContext *context) {
771  results.add<DimOfToTensorFolder>(context);
772 }
774 //===----------------------------------------------------------------------===//
775 // ToMemrefOp
776 //===----------------------------------------------------------------------===//
778 OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
779  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
780  if (memrefToTensor.getMemref().getType() == getType())
781  return memrefToTensor.getMemref();
782  return {};
783 }
785 namespace {
787 /// Replace tensor.cast + to_memref by to_memref + memref.cast.
788 struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
791  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
792  PatternRewriter &rewriter) const final {
793  auto tensorCastOperand =
794  toMemref.getOperand().getDefiningOp<tensor::CastOp>();
795  if (!tensorCastOperand)
796  return failure();
797  auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
798  tensorCastOperand.getOperand().getType());
799  if (!srcTensorType)
800  return failure();
801  auto memrefType = MemRefType::get(srcTensorType.getShape(),
802  srcTensorType.getElementType());
803  Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
804  tensorCastOperand.getOperand());
805  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
806  memref);
807  return success();
808  }
809 };
811 /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
812 /// cast if necessary.
813 struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
816  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
817  PatternRewriter &rewriter) const final {
819  options.bufferAlignment = 0;
820  return foldToMemrefToTensorPair(rewriter, toMemref, options);
821  }
822 };
824 /// Fold a load on a to_memref operation into an tensor.extract on the
825 /// corresponding tensor.
826 struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
829  LogicalResult matchAndRewrite(memref::LoadOp load,
830  PatternRewriter &rewriter) const override {
831  auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
832  if (!toMemref)
833  return failure();
835  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
836  load.getIndices());
837  return success();
838  }
839 };
841 /// Fold dim of a to_memref into the dim of the tensor.
842 struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
845  LogicalResult matchAndRewrite(memref::DimOp dimOp,
846  PatternRewriter &rewriter) const override {
847  auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
848  if (!castOp)
849  return failure();
850  Value newSource = castOp.getOperand();
851  rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
852  dimOp.getIndex());
853  return success();
854  }
855 };
857 } // namespace
859 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
860  MLIRContext *context) {
861  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
862  ToMemrefToTensorFolding>(context);
863 }
865 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
866  const BufferizationOptions &options) {
867  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
868  (void)foldToMemrefToTensorPair(rewriter, *this, options);
869  // Note: The return value of `bufferize` indicates whether there was an error
870  // or not. (And not whether the pattern matched or not.)
871  return success();
872 }
874 std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
875  Value alloc) {
876  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
877  .getOperation();
878 }
880 std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
881  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
882 }
884 //===----------------------------------------------------------------------===//
885 // DeallocOp
886 //===----------------------------------------------------------------------===//
888 LogicalResult DeallocOp::inferReturnTypes(
889  MLIRContext *context, std::optional<::mlir::Location> location,
890  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
891  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
892  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
893  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
894  IntegerType::get(context, 1));
895  return success();
896 }
898 LogicalResult DeallocOp::verify() {
899  if (getMemrefs().size() != getConditions().size())
900  return emitOpError(
901  "must have the same number of conditions as memrefs to deallocate");
902  if (getRetained().size() != getUpdatedConditions().size())
903  return emitOpError("must have the same number of updated conditions "
904  "(results) as retained operands");
905  return success();
906 }
908 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
909  ValueRange memrefs,
910  ValueRange conditions,
911  PatternRewriter &rewriter) {
912  if (deallocOp.getMemrefs() == memrefs &&
913  deallocOp.getConditions() == conditions)
914  return failure();
916  rewriter.modifyOpInPlace(deallocOp, [&]() {
917  deallocOp.getMemrefsMutable().assign(memrefs);
918  deallocOp.getConditionsMutable().assign(conditions);
919  });
920  return success();
921 }
923 namespace {
925 /// Remove duplicate values in the list of memrefs to be deallocated. We need to
926 /// make sure the corresponding condition value is updated accordingly since
927 /// their two conditions might not cover the same set of cases. In that case, we
928 /// have to combine them (by computing the disjunction of them).
929 /// Example:
930 /// ```mlir
931 /// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
932 /// ```
933 /// is canonicalized to
934 /// ```mlir
935 /// %0 = arith.ori %arg1, %arg2 : i1
936 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
937 /// ```
938 struct DeallocRemoveDuplicateDeallocMemrefs
939  : public OpRewritePattern<DeallocOp> {
942  LogicalResult matchAndRewrite(DeallocOp deallocOp,
943  PatternRewriter &rewriter) const override {
944  // Unique memrefs to be deallocated.
945  DenseMap<Value, unsigned> memrefToCondition;
946  SmallVector<Value> newMemrefs, newConditions;
947  for (auto [i, memref, cond] :
948  llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
949  if (memrefToCondition.count(memref)) {
950  // If the dealloc conditions don't match, we need to make sure that the
951  // dealloc happens on the union of cases.
952  Value &newCond = newConditions[memrefToCondition[memref]];
953  if (newCond != cond)
954  newCond =
955  rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
956  } else {
957  memrefToCondition.insert({memref, newConditions.size()});
958  newMemrefs.push_back(memref);
959  newConditions.push_back(cond);
960  }
961  }
963  // Return failure if we don't change anything such that we don't run into an
964  // infinite loop of pattern applications.
965  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
966  rewriter);
967  }
968 };
970 /// Remove duplicate values in the list of retained memrefs. We need to make
971 /// sure the corresponding result condition value is replaced properly.
972 /// Example:
973 /// ```mlir
974 /// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
975 /// ```
976 /// is canonicalized to
977 /// ```mlir
978 /// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
979 /// ```
980 struct DeallocRemoveDuplicateRetainedMemrefs
981  : public OpRewritePattern<DeallocOp> {
984  LogicalResult matchAndRewrite(DeallocOp deallocOp,
985  PatternRewriter &rewriter) const override {
986  // Unique retained values
988  SmallVector<Value> newRetained;
989  SmallVector<unsigned> resultReplacementIdx;
990  unsigned i = 0;
991  for (auto retained : deallocOp.getRetained()) {
992  if (seen.count(retained)) {
993  resultReplacementIdx.push_back(seen[retained]);
994  continue;
995  }
997  seen[retained] = i;
998  newRetained.push_back(retained);
999  resultReplacementIdx.push_back(i++);
1000  }
1002  // Return failure if we don't change anything such that we don't run into an
1003  // infinite loop of pattern applications.
1004  if (newRetained.size() == deallocOp.getRetained().size())
1005  return failure();
1007  // We need to create a new op because the number of results is always the
1008  // same as the number of condition operands.
1009  auto newDeallocOp =
1010  rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
1011  deallocOp.getConditions(), newRetained);
1012  SmallVector<Value> replacements(
1013  llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
1014  return newDeallocOp.getUpdatedConditions()[idx];
1015  }));
1016  rewriter.replaceOp(deallocOp, replacements);
1017  return success();
1018  }
1019 };
1021 /// Erase deallocation operations where the variadic list of memrefs to
1022 /// deallocate is empty. Example:
1023 /// ```mlir
1024 /// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1025 /// ```
1026 struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1029  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1030  PatternRewriter &rewriter) const override {
1031  if (deallocOp.getMemrefs().empty()) {
1032  Value constFalse = rewriter.create<arith::ConstantOp>(
1033  deallocOp.getLoc(), rewriter.getBoolAttr(false));
1034  rewriter.replaceOp(
1035  deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1036  constFalse));
1037  return success();
1038  }
1039  return failure();
1040  }
1041 };
1043 /// Removes memrefs from the deallocation list if their associated condition is
1044 /// always 'false'.
1045 ///
1046 /// Example:
1047 /// ```
1048 /// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1049 /// if (%arg2, %false)
1050 /// ```
1051 /// becomes
1052 /// ```
1053 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1054 /// ```
1055 struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1058  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1059  PatternRewriter &rewriter) const override {
1060  SmallVector<Value> newMemrefs, newConditions;
1061  for (auto [memref, cond] :
1062  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1063  if (!matchPattern(cond, m_Zero())) {
1064  newMemrefs.push_back(memref);
1065  newConditions.push_back(cond);
1066  }
1067  }
1069  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1070  rewriter);
1071  }
1072 };
1074 /// The `memref.extract_strided_metadata` is often inserted to get the base
1075 /// memref if the operand is not already guaranteed to be the result of a memref
1076 /// allocation operation. This canonicalization pattern removes this extraction
1077 /// operation if the operand is now produced by an allocation operation (e.g.,
1078 /// due to other canonicalizations simplifying the IR).
1079 ///
1080 /// Example:
1081 /// ```mlir
1082 /// %alloc = memref.alloc() : memref<2xi32>
1083 /// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1084 /// %alloc : memref<2xi32> -> memref<i32>, index, index, index
1085 /// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1086 /// ```
1087 /// is canonicalized to
1088 /// ```mlir
1089 /// %alloc = memref.alloc() : memref<2xi32>
1090 /// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1091 /// ```
1092 struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1095  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1096  PatternRewriter &rewriter) const override {
1097  SmallVector<Value> newMemrefs(
1098  llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1099  auto extractStridedOp =
1100  memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1101  if (!extractStridedOp)
1102  return memref;
1103  Value allocMemref = extractStridedOp.getOperand();
1104  auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1105  if (!allocOp)
1106  return memref;
1107  if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1108  return allocMemref;
1109  return memref;
1110  }));
1112  return updateDeallocIfChanged(deallocOp, newMemrefs,
1113  deallocOp.getConditions(), rewriter);
1114  }
1115 };
1117 /// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1118 /// other user of the allocated value and the allocating operation can be safely
1119 /// removed. If the same value is present multiple times, this pattern relies on
1120 /// other canonicalization patterns to remove the duplicate first.
1121 ///
1122 /// Example:
1123 /// ```mlir
1124 /// %alloc = memref.alloc() : memref<2xi32>
1125 /// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
1126 /// ```
1127 /// is canonicalized to
1128 /// ```mlir
1129 /// bufferization.dealloc (%arg0 : ...) if (%true)
1130 /// ```
1131 struct RemoveAllocDeallocPairWhenNoOtherUsers
1132  : public OpRewritePattern<DeallocOp> {
1135  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1136  PatternRewriter &rewriter) const override {
1137  SmallVector<Value> newMemrefs, newConditions;
1138  SmallVector<Operation *> toDelete;
1139  for (auto [memref, cond] :
1140  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1141  if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1142  // Check that it is indeed an allocate effect, that the op has no other
1143  // side effects (which would not allow us to remove the op), and that
1144  // there are no other users.
1145  if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1146  hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
1147  memref.hasOneUse()) {
1148  toDelete.push_back(allocOp);
1149  continue;
1150  }
1151  }
1153  newMemrefs.push_back(memref);
1154  newConditions.push_back(cond);
1155  }
1157  if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1158  rewriter)))
1159  return failure();
1161  for (Operation *op : toDelete)
1162  rewriter.eraseOp(op);
1164  return success();
1165  }
1166 };
1168 } // anonymous namespace
1170 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1171  MLIRContext *context) {
1173 }
1176  RewritePatternSet &patterns, MLIRContext *context) {
1177  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1178  DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1179  EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1180  RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1181 }
1183 //===----------------------------------------------------------------------===//
1184 // TableGen'd op method definitions
1185 //===----------------------------------------------------------------------===//
1187 #define GET_OP_CLASSES
1188 #include "mlir/Dialect/Bufferization/IR/"
