//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

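/// Try to cast the given ranked memref value to the given ranked memref type.
/// If the cast is not guaranteed to succeed at runtime, allocate a buffer of
/// the target type and copy into it instead. A sketch of the copy path
/// (illustrative, not taken from the original source):
///
/// ```mlir
/// %alloc = memref.alloc() : memref<4xf32>
/// memref.copy %src, %alloc
///     : memref<?xf32, strided<[?], offset: ?>> to memref<4xf32>
/// ```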
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpace() != destType.getMemorySpace())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
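/// A sketch of the fold (illustrative, not from the original source):
///
/// ```mlir
/// %t = bufferization.to_tensor %m : memref<4xf32>
/// %r = bufferization.to_memref %t : memref<?xf32>
/// ```
/// Uses of `%r` are replaced with `memref.cast %m : memref<4xf32> to
/// memref<?xf32>` (or with `%m` directly if the types already match).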
LogicalResult
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

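/// Populate `dynamicDims` with tensor::DimOp / memref::DimOp results for all
/// dynamic dimensions of the given shaped value.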
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}

bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType =
        bufferization::getBufferType(getCopy(), options, invocationStack);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escape the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was defined using a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

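/// Fold `tensor.dim` of a dynamic dimension of a `bufferization.alloc_tensor`
/// to the corresponding dynamic size operand. A sketch of the rewrite
/// (illustrative, not from the original source):
///
/// ```mlir
/// %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
/// %d = tensor.dim %0, %c0 : tensor<?xf32>
/// ```
/// Uses of `%d` are replaced with `%sz`.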
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

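// The custom assembly format handled by the parser and printer below. Assumed
// examples of the accepted syntax (illustrative, not from the original
// source):
//
//   %0 = bufferization.alloc_tensor(%d0) : tensor<?x5xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) : tensor<8xf32>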
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
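/// A sketch of the rewrite (illustrative, not from the original source):
///
/// ```mlir
/// %clone = bufferization.clone %src : memref<?xf32> to memref<?xf32>
/// ...
/// memref.dealloc %clone : memref<?xf32>
/// ```
/// Uses of `%clone` are replaced with `%src` (or `memref.cast %src` if the
/// types differ), and the now redundant deallocation is erased.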
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aim to find the dealloc op for the canonical source, which otherwise
    // could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either the clone or the source has more than one dealloc op.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for a deallocation op.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize in-place
  // (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

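// Bufferization replaces this op with a copy from the source buffer into the
// destination buffer, created via options.createMemCpy (a memref.copy by
// default). A sketch of the resulting IR (assumed example, not from the
// original source):
//
//   memref.copy %src_buffer, %dest_buffer : memref<?xf32> to memref<?xf32>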
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {
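/// Fold `tensor.dim` of a `bufferization.to_tensor` into a `memref.dim` on
/// the underlying memref. A sketch of the rewrite (illustrative, not from the
/// original source):
///
/// ```mlir
/// %t = bufferization.to_tensor %m : memref<?xf32>
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```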
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
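/// A sketch of the rewrite (illustrative, not from the original source):
///
/// ```mlir
/// %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// %m = bufferization.to_memref %c : memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// %m0 = bufferization.to_memref %t : memref<4xf32>
/// %m = memref.cast %m0 : memref<4xf32> to memref<?xf32>
/// ```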
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    return foldToMemrefToTensorPair(rewriter, toMemref);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
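/// A sketch of the rewrite (illustrative, not from the original source):
///
/// ```mlir
/// %m = bufferization.to_memref %t : memref<?xf32>
/// %v = memref.load %m[%i] : memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// %v = tensor.extract %t[%i] : tensor<?xf32>
/// ```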
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
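/// A sketch of the rewrite (illustrative, not from the original source):
///
/// ```mlir
/// %m = bufferization.to_memref %t : memref<?xf32>
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```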
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. We need
/// to make sure the corresponding condition value is updated accordingly,
/// since the two conditions might not cover the same set of cases. In that
/// case, we combine them by computing their disjunction.
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if nothing changed, so that we do not run into an
    // infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if nothing changed, so that we do not run into an
    // infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of retained operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
/// The op is erased and each updated condition result (here `%0`) is replaced
/// with a constant `false`.
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition is
/// always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                           if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// A `memref.extract_strided_metadata` op is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"