//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpace() != destType.getMemorySpace())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, a cast is valid, but it may
  // still fail at runtime. To ensure that we only generate casts that always
  // succeed at runtime, we check a few extra conditions in
  // `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }

  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
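// Illustrative note (not part of the original file): a cast such as
//   memref<?xf32, strided<[?], offset: ?>> -> memref<4xf32>
// would turn a dynamic offset/stride into a static one, which is cast
// compatible but not guaranteed to succeed at runtime. In that case the
// function above allocates a fresh buffer of the destination type and copies
// the data instead of emitting a memref.cast.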

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref,
    const BufferizationOptions &options) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType, options);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}
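// Illustrative example (not from the original file):
//   %t = bufferization.to_tensor %m : memref<4xf32>
//   %r = bufferization.to_memref %t : memref<4xf32>
// folds to a direct use of %m; if the two memref types differ but are cast
// compatible, a memref.cast from %m to the result type is inserted instead.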

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}
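// Illustrative example (not from the original file): with default options,
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
// bufferizes to
//   %m = memref.alloc(%sz) : memref<?xf32>
// and a `copy` operand additionally yields a memcpy from the bufferized
// copy source into %m.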

bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType =
        bufferization::getBufferType(getCopy(), options, invocationStack);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation but whose size was defined using a
/// `constant` op. For example:
///
/// %c5 = arith.constant 5: index
/// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
/// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

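/// Fold `tensor.dim` of a dynamic dimension of an `alloc_tensor` to the
/// corresponding dynamic size operand. Illustrative example (not from the
/// original file):
///
/// %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
/// %d = tensor.dim %0, %c0 : tensor<?xf32>
///
/// folds %d to a direct use of %sz.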
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}
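// Illustrative examples of the custom assembly format handled by the
// parse/print methods above (hypothetical operands, not from the original
// file):
//   %0 = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) size_hint=%h : tensor<8xf32>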

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
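/// Illustrative example (not from the original file):
///
/// %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
/// "use"(%1) : (memref<?xf32>) -> ()
/// memref.dealloc %1 : memref<?xf32>
///
/// becomes a direct use of %0 with the redundant dealloc erased, provided the
/// checks below prove this is safe.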
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Try to find the dealloc op for the canonical source, which otherwise
    // could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either the clone or the source has more than one dealloc.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for a deallocation op.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize
  // in-place (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}
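// Illustrative example (not from the original file): in
//   %m = bufferization.to_memref %t : memref<4xf32>
//   %u = bufferization.to_tensor %m : memref<4xf32>
// %u folds to %t, but only because the to_tensor immediately follows the
// to_memref in the same block, so no interleaved operation can have written
// to %m.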

namespace {
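/// Replace `tensor.dim` of a `to_tensor` result with `memref.dim` on the
/// underlying memref. Illustrative example (not from the original file):
///
/// %t = bufferization.to_tensor %m : memref<?xf32>
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes
///
/// %d = memref.dim %m, %c0 : memref<?xf32>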
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
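/// Illustrative example (not from the original file):
///
/// %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// %m = bufferization.to_memref %c : memref<?xf32>
///
/// becomes
///
/// %m0 = bufferization.to_memref %t : memref<4xf32>
/// %m = memref.cast %m0 : memref<4xf32> to memref<?xf32>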
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
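/// Illustrative example (not from the original file):
///
/// %t = bufferization.to_tensor %m : memref<4xf32>
/// %r = bufferization.to_memref %t : memref<4xf32>
///
/// is rewritten to a direct use of %m via `foldToMemrefToTensorPair`.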
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToMemrefToTensorPair(rewriter, toMemref, options);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
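/// Illustrative example (not from the original file):
///
/// %m = bufferization.to_memref %t : memref<?xf32>
/// %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes
///
/// %v = tensor.extract %t[%i] : tensor<?xf32>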
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
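/// Illustrative example (not from the original file):
///
/// %m = bufferization.to_memref %t : memref<?xf32>
/// %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes
///
/// %d = tensor.dim %t, %c0 : tensor<?xf32>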
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this, options);
  // Note: The return value of `bufferize` indicates whether there was an
  // error or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. We need
/// to make sure the corresponding condition value is updated accordingly,
/// since the two conditions might not cover the same set of cases. In that
/// case, we have to combine them (by computing the disjunction of them).
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that
        // the dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything, so that we don't run into
    // an infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything, so that we don't run into
    // an infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of condition operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Each updated condition result is replaced by a
/// 'false' constant. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition
/// is always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                       if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// The `memref.extract_strided_metadata` op is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no
        // other side effects (which would not allow us to remove the op), and
        // that there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"