MLIR  19.0.0git
BufferizationOps.cpp
Go to the documentation of this file.
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/Matchers.h"
17 #include <optional>
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 
22 //===----------------------------------------------------------------------===//
23 // Helper functions
24 //===----------------------------------------------------------------------===//
25 
27  OpBuilder &b, Value value, MemRefType destType,
29  auto srcType = llvm::cast<MemRefType>(value.getType());
30 
31  // Element type, rank and memory space must match.
32  if (srcType.getElementType() != destType.getElementType())
33  return failure();
34  if (srcType.getMemorySpace() != destType.getMemorySpace())
35  return failure();
36  if (srcType.getRank() != destType.getRank())
37  return failure();
38 
39  // In case the affine maps are different, we may need to use a copy if we go
40  // from dynamic to static offset or stride (the canonicalization cannot know
41  // at this point that it is really cast compatible).
42  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
43  int64_t sourceOffset, targetOffset;
44  SmallVector<int64_t, 4> sourceStrides, targetStrides;
45  if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
46  failed(getStridesAndOffset(target, targetStrides, targetOffset)))
47  return false;
48  auto dynamicToStatic = [](int64_t a, int64_t b) {
49  return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
50  };
51  if (dynamicToStatic(sourceOffset, targetOffset))
52  return false;
53  for (auto it : zip(sourceStrides, targetStrides))
54  if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
55  return false;
56  return true;
57  };
58 
59  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
60  // ensure that we only generate casts that always succeed at runtime, we check
61  // a few extra conditions in `isGuaranteedCastCompatible`.
62  if (memref::CastOp::areCastCompatible(srcType, destType) &&
63  isGuaranteedCastCompatible(srcType, destType)) {
64  Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
65  return casted;
66  }
67 
68  auto loc = value.getLoc();
69  SmallVector<Value, 4> dynamicOperands;
70  for (int i = 0; i < destType.getRank(); ++i) {
71  if (destType.getShape()[i] != ShapedType::kDynamic)
72  continue;
73  Value size = b.create<memref::DimOp>(loc, value, i);
74  dynamicOperands.push_back(size);
75  }
76 
77  FailureOr<Value> copy =
78  options.createAlloc(b, loc, destType, dynamicOperands);
79  if (failed(copy))
80  return failure();
81  if (failed(options.createMemCpy(b, loc, value, *copy)))
82  return failure();
83  return copy;
84 }
85 
86 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
87 /// to_memref op are different, a memref.cast is needed.
89  RewriterBase &rewriter, ToMemrefOp toMemref,
91  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
92  if (!memrefToTensor)
93  return failure();
94 
95  Type srcType = memrefToTensor.getMemref().getType();
96  Type destType = toMemref.getType();
97 
98  // Directly rewrite if the type did not change.
99  if (srcType == destType) {
100  rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
101  return success();
102  }
103 
104  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
105  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
106  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
107 
108  // Ranked memref -> Ranked memref cast.
109  if (rankedSrcType && rankedDestType) {
110  FailureOr<Value> replacement = castOrReallocMemRefValue(
111  rewriter, memrefToTensor.getMemref(), rankedDestType, options);
112  if (failed(replacement))
113  return failure();
114 
115  rewriter.replaceOp(toMemref, *replacement);
116  return success();
117  }
118 
119  // Unranked memref -> Ranked memref cast: May require a copy.
120  // TODO: Not implemented at the moment.
121  if (unrankedSrcType && rankedDestType)
122  return failure();
123 
124  // Unranked memref -> unranked memref cast
125  // Ranked memref -> unranked memref cast: No copy needed.
126  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
127  "expected that types are cast compatible");
128  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
129  memrefToTensor.getMemref());
130  return success();
131 }
132 
134  OpBuilder &b, Location loc, Value shapedValue,
135  SmallVector<Value> &dynamicDims) {
136  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
137  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
138  if (shapedType.isDynamicDim(i)) {
139  if (llvm::isa<MemRefType>(shapedType)) {
140  dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
141  } else {
142  assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
143  dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
144  }
145  }
146  }
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // AllocTensorOp
151 //===----------------------------------------------------------------------===//
152 
153 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
154  const BufferizationOptions &options) {
155  OpBuilder::InsertionGuard g(rewriter);
156  Location loc = getLoc();
157 
158  // Nothing to do for dead AllocTensorOps.
159  if (getOperation()->getUses().empty()) {
160  rewriter.eraseOp(getOperation());
161  return success();
162  }
163 
164  // Get "copy" buffer.
165  Value copyBuffer;
166  if (getCopy()) {
167  FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
168  if (failed(maybeCopyBuffer))
169  return failure();
170  copyBuffer = *maybeCopyBuffer;
171  }
172 
173  // Create memory allocation.
174  auto allocType = bufferization::getBufferType(getResult(), options);
175  if (failed(allocType))
176  return failure();
177  SmallVector<Value> dynamicDims = getDynamicSizes();
178  if (getCopy()) {
179  assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
180  populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
181  }
182  FailureOr<Value> alloc = options.createAlloc(
183  rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
184  if (failed(alloc))
185  return failure();
186 
187  // Create memory copy (if any).
188  if (getCopy()) {
189  if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
190  return failure();
191  }
192 
193  // Replace op.
194  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
195 
196  return success();
197 }
198 
199 bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
200  const AnalysisState &state) {
201  // AllocTensorOps do not write unless they have a `copy` value.
202  return static_cast<bool>(getCopy());
203 }
204 
205 bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
206  const AnalysisState &state) {
207  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
208  "expected copy operand");
209  return true;
210 }
211 
212 bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
213  const AnalysisState &state) {
214  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
215  "expected copy operand");
216  return false;
217 }
218 
219 AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
220  const AnalysisState &state) {
221  // This is a new allocation. It does not alias with any other buffer.
222  return {};
223 }
224 
225 FailureOr<BaseMemRefType>
227  SmallVector<Value> &invocationStack) {
228  assert(value == getResult() && "invalid value");
229 
230  // Compute memory space of this allocation.
231  Attribute memorySpace;
232  if (getMemorySpace().has_value()) {
233  memorySpace = *getMemorySpace();
234  } else if (getCopy()) {
235  auto copyBufferType =
236  bufferization::getBufferType(getCopy(), options, invocationStack);
237  if (failed(copyBufferType))
238  return failure();
239  memorySpace = copyBufferType->getMemorySpace();
240  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
241  memorySpace = *ms;
242  } else {
243  return getOperation()->emitError("could not infer memory space");
244  }
245 
246  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
247 }
248 
249 LogicalResult AllocTensorOp::verify() {
250  if (getCopy() && !getDynamicSizes().empty())
251  return emitError("dynamic sizes not needed when copying a tensor");
252  if (!getCopy() && getType().getNumDynamicDims() !=
253  static_cast<int64_t>(getDynamicSizes().size()))
254  return emitError("expected ")
255  << getType().getNumDynamicDims() << " dynamic sizes";
256  if (getCopy() && getCopy().getType() != getType())
257  return emitError("expected that `copy` and return type match");
258  return success();
259 }
260 
261 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
262  RankedTensorType type, ValueRange dynamicSizes) {
263  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
264  /*size_hint=*/Value(),
265  /*memory_space=*/IntegerAttr());
266 }
267 
268 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
269  RankedTensorType type, ValueRange dynamicSizes,
270  Value copy) {
271  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
272  /*memory_space=*/IntegerAttr());
273 }
274 
275 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
276  TensorType type, ValueRange dynamicSizes, Value copy,
277  IntegerAttr memorySpace) {
278  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
279  memorySpace);
280 }
281 
282 namespace {
283 /// Change the type of the result of a `bufferization.alloc_tensor` by making
284 /// the result type statically sized along dimensions that in the original
285 /// operation were defined as dynamic, but whose size was defined using a
286 /// `constant` op. For example:
287 ///
288 /// %c5 = arith.constant 5: index
289 /// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
290 ///
291 /// to
292 ///
293 /// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
294 struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
296 
297  LogicalResult matchAndRewrite(AllocTensorOp op,
298  PatternRewriter &rewriter) const override {
299  if (op.getCopy())
300  return failure();
301  SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
302  SmallVector<Value> newDynamicSizes;
303  unsigned int dynValCounter = 0;
304  for (int64_t i = 0; i < op.getType().getRank(); ++i) {
305  if (!op.isDynamicDim(i))
306  continue;
307  Value value = op.getDynamicSizes()[dynValCounter++];
308  APInt intVal;
309  if (matchPattern(value, m_ConstantInt(&intVal))) {
310  int64_t dim = intVal.getSExtValue();
311  if (dim >= 0)
312  newShape[i] = intVal.getSExtValue();
313  else
314  newDynamicSizes.push_back(value);
315  } else {
316  newDynamicSizes.push_back(value);
317  }
318  }
319  RankedTensorType newType = RankedTensorType::get(
320  newShape, op.getType().getElementType(), op.getType().getEncoding());
321  if (newType == op.getType())
322  return failure();
323  auto newOp = rewriter.create<AllocTensorOp>(
324  op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
325  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
326  return success();
327  }
328 };
329 
330 struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
332 
333  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
334  PatternRewriter &rewriter) const override {
335  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
336  auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
337  if (!allocTensorOp || !maybeConstantIndex)
338  return failure();
339  if (*maybeConstantIndex < 0 ||
340  *maybeConstantIndex >= allocTensorOp.getType().getRank())
341  return failure();
342  if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
343  return failure();
344  rewriter.replaceOp(
345  dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
346  return success();
347  }
348 };
349 } // namespace
350 
351 void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
352  MLIRContext *ctx) {
353  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
354 }
355 
357  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
358  auto shapes = llvm::to_vector<4>(
359  llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
360  [&](int64_t dim) -> OpFoldResult {
361  if (isDynamicDim(dim))
362  return getDynamicSize(builder, dim);
363  return builder.getIndexAttr(getStaticSize(dim));
364  }));
365  reifiedReturnShapes.emplace_back(std::move(shapes));
366  return success();
367 }
368 
369 ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
370  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
371  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
372  parser.parseRParen())
373  return failure();
374  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
375  OpAsmParser::UnresolvedOperand copyOperand;
376  if (copyKeyword.succeeded())
377  if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
378  parser.parseRParen())
379  return failure();
380  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
381  OpAsmParser::UnresolvedOperand sizeHintOperand;
382  if (sizeHintKeyword.succeeded())
383  if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
384  return failure();
385  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
386  return failure();
387 
388  TensorType type;
389  if (parser.parseCustomTypeWithFallback(type))
390  return failure();
391  result.addTypes(type);
392 
393  Type indexType = parser.getBuilder().getIndexType();
394  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
395  return failure();
396  if (copyKeyword.succeeded())
397  if (parser.resolveOperand(copyOperand, type, result.operands))
398  return failure();
399  if (sizeHintKeyword.succeeded())
400  if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
401  return failure();
402  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
404  {static_cast<int32_t>(dynamicSizesOperands.size()),
405  static_cast<int32_t>(copyKeyword.succeeded()),
406  static_cast<int32_t>(sizeHintKeyword.succeeded())}));
407  return success();
408 }
409 
411  p << "(" << getDynamicSizes() << ")";
412  if (getCopy())
413  p << " copy(" << getCopy() << ")";
414  if (getSizeHint())
415  p << " size_hint=" << getSizeHint();
416  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
417  AllocTensorOp::getOperandSegmentSizeAttr()});
418  p << " : ";
419  auto type = getResult().getType();
420  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
421  p.printStrippedAttrOrType(validType);
422  else
423  p << type;
424 }
425 
426 Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
427  assert(isDynamicDim(idx) && "expected dynamic dim");
428  if (getCopy())
429  return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
430  return getOperand(getIndexOfDynamicSize(idx));
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // CloneOp
435 //===----------------------------------------------------------------------===//
436 
437 OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
438  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
439 }
440 
441 namespace {
442 
443 /// Merge the clone and its source (by converting the clone to a cast) when
444 /// possible.
445 struct SimplifyClones : public OpRewritePattern<CloneOp> {
447 
448  LogicalResult matchAndRewrite(CloneOp cloneOp,
449  PatternRewriter &rewriter) const override {
450  if (cloneOp.use_empty()) {
451  rewriter.eraseOp(cloneOp);
452  return success();
453  }
454 
455  Value source = cloneOp.getInput();
456  if (source.getType() != cloneOp.getType() &&
457  !memref::CastOp::areCastCompatible({source.getType()},
458  {cloneOp.getType()}))
459  return failure();
460 
461  // Aims to find the dealloc op for the canonical source
462  // which otherwise could prevent removal of unnecessary allocs.
463  Value canonicalSource = source;
464  while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
465  canonicalSource.getDefiningOp()))
466  canonicalSource = iface.getViewSource();
467 
468  std::optional<Operation *> maybeCloneDeallocOp =
469  memref::findDealloc(cloneOp.getOutput());
470  // Skip if either of them has > 1 deallocate operations.
471  if (!maybeCloneDeallocOp.has_value())
472  return failure();
473  std::optional<Operation *> maybeSourceDeallocOp =
474  memref::findDealloc(canonicalSource);
475  if (!maybeSourceDeallocOp.has_value())
476  return failure();
477  Operation *cloneDeallocOp = *maybeCloneDeallocOp;
478  Operation *sourceDeallocOp = *maybeSourceDeallocOp;
479 
480  // If both are deallocated in the same block, their in-block lifetimes
481  // might not fully overlap, so we cannot decide which one to drop.
482  if (cloneDeallocOp && sourceDeallocOp &&
483  cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
484  return failure();
485 
486  Block *currentBlock = cloneOp->getBlock();
487  Operation *redundantDealloc = nullptr;
488  if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
489  redundantDealloc = cloneDeallocOp;
490  } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
491  redundantDealloc = sourceDeallocOp;
492  }
493 
494  if (!redundantDealloc)
495  return failure();
496 
497  // Safety check that there are no other deallocations in between
498  // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
499  // of source before the uses of the clone. With alias information, we could
500  // restrict this to only fail if the dealloc's operand is an alias
501  // of the source.
502  for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
503  pos = pos->getNextNode()) {
504  // Bail if we run out of operations while looking for a deallocation op.
505  if (!pos)
506  return failure();
507  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
508  if (!effectInterface)
509  continue;
510  if (effectInterface.hasEffect<MemoryEffects::Free>())
511  return failure();
512  }
513 
514  if (source.getType() != cloneOp.getType())
515  source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
516  cloneOp.getType(), source);
517  rewriter.replaceOp(cloneOp, source);
518  rewriter.eraseOp(redundantDealloc);
519  return success();
520  }
521 };
522 
523 } // namespace
524 
525 void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
526  MLIRContext *context) {
527  results.add<SimplifyClones>(context);
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // DeallocTensorOp
532 //===----------------------------------------------------------------------===//
533 
534 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
535  const BufferizationOptions &options) {
536  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
537  if (failed(buffer))
538  return failure();
539  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
540  rewriter.eraseOp(getOperation());
541  return success();
542 }
543 
544 //===----------------------------------------------------------------------===//
545 // MaterializeInDestinationOp
546 //===----------------------------------------------------------------------===//
547 
548 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
549  OpOperand &opOperand, const AnalysisState &state) {
550  return opOperand == getSourceMutable();
551 }
552 
553 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
554  OpOperand &opOperand, const AnalysisState &state) {
555  if (opOperand == getDestMutable()) {
556  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
557  return true;
558  }
559  return false;
560 }
561 
562 bool MaterializeInDestinationOp::mustBufferizeInPlace(
563  OpOperand &opOperand, const AnalysisState &state) {
564  // The source is only read and not written, so it always bufferizes in-place
565  // by default. The destination is written and is forced to bufferize in-place
566  // (if it is a tensor).
567  return true;
568 }
569 
571 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
572  const AnalysisState &state) {
573  if (opOperand == getDestMutable()) {
574  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
575  return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
576  }
577  return {};
578 }
579 
580 LogicalResult
581 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
582  const BufferizationOptions &options) {
583  bool tensorDest = isa<TensorType>(getDest().getType());
584  Value buffer;
585  if (tensorDest) {
586  FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
587  if (failed(maybeBuffer))
588  return failure();
589  buffer = *maybeBuffer;
590  } else {
591  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
592  buffer = getDest();
593  }
594  auto srcBuffer = getBuffer(rewriter, getSource(), options);
595  if (failed(srcBuffer))
596  return failure();
597  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
598  return failure();
599  replaceOpWithBufferizedValues(rewriter, getOperation(),
600  tensorDest ? ValueRange(buffer) : ValueRange());
601  return success();
602 }
603 
604 bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
605  const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
606  // As elements are copied from the "source" buffer to the "dest" buffer,
607  // already copied elements are not read a second time.
608  return true;
609 }
610 
612  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
613  if (getOperation()->getNumResults() == 1) {
614  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
615  reifiedReturnShapes.resize(1,
616  SmallVector<OpFoldResult>(getType().getRank()));
617  reifiedReturnShapes[0] =
618  tensor::getMixedSizes(builder, getLoc(), getDest());
619  }
620  return success();
621 }
622 
623 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
624  Location loc) {
625  if (isa<TensorType>(getDest().getType())) {
626  // The subset is the entire destination tensor.
627  return getDest();
628  }
629 
630  // The "restrict" attribute is transferred from this op to the newly created
631  // to_tensor op. If this op does not the "restrict" attribute, the subset
632  // extraction cannot be built because there is no guarantee that there is no
633  // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
634  if (!getRestrict())
635  return {};
636 
637  // Build a bufferization.to_tensor op.
638  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
639  assert(getRestrict() &&
640  "expected that ops with memrefs dest have 'restrict'");
641  setRestrict(false);
642  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
643  getWritable());
644 }
645 
646 bool MaterializeInDestinationOp::isEquivalentSubset(
647  Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
648  return equivalenceFn(getDest(), candidate);
649 }
650 
652 MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
653  return {getDest()};
654 }
655 
656 OpOperand &MaterializeInDestinationOp::getSourceOperand() {
657  return getOperation()->getOpOperand(0) /*source*/;
658 }
659 
660 bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
661  SubsetOpInterface subsetOp,
662  function_ref<bool(Value, Value)> equivalenceFn) {
663  return false;
664 }
665 
666 bool MaterializeInDestinationOp::operatesOnDisjointSubset(
667  SubsetOpInterface subsetOp,
668  function_ref<bool(Value, Value)> equivalenceFn) {
669  return false;
670 }
671 
672 LogicalResult MaterializeInDestinationOp::verify() {
673  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
674  return emitOpError("'dest' must be a tensor or a memref");
675  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
676  if (getOperation()->getNumResults() != 1)
677  return emitOpError("tensor 'dest' implies exactly one tensor result");
678  if (destType != getResult().getType())
679  return emitOpError("result and 'dest' types must match");
680  }
681  if (isa<BaseMemRefType>(getDest().getType()) &&
682  getOperation()->getNumResults() != 0)
683  return emitOpError("memref 'dest' implies zero results");
684  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
685  return emitOpError("'restrict' is valid only for memref destinations");
686  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
687  return emitOpError("'writable' must be specified if and only if the "
688  "destination is of memref type");
689  TensorType srcType = getSource().getType();
690  ShapedType destType = cast<ShapedType>(getDest().getType());
691  if (srcType.hasRank() != destType.hasRank())
692  return emitOpError("source/destination shapes are incompatible");
693  if (srcType.hasRank()) {
694  if (srcType.getRank() != destType.getRank())
695  return emitOpError("rank mismatch between source and destination shape");
696  for (auto [src, dest] :
697  llvm::zip(srcType.getShape(), destType.getShape())) {
698  if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
699  // Cannot verify dynamic dimension size. Assume that that they match at
700  // runtime.
701  continue;
702  }
703  if (src != dest)
704  return emitOpError("source/destination shapes are incompatible");
705  }
706  }
707  return success();
708 }
709 
710 void MaterializeInDestinationOp::build(OpBuilder &builder,
711  OperationState &state, Value source,
712  Value dest) {
713  auto destTensorType = dyn_cast<TensorType>(dest.getType());
714  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
715  source, dest);
716 }
717 
718 bool MaterializeInDestinationOp::isWritable(Value value,
719  const AnalysisState &state) {
720  return isa<TensorType>(getDest().getType()) ? true : getWritable();
721 }
722 
723 MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
724  return getDestMutable();
725 }
726 
727 void MaterializeInDestinationOp::getEffects(
729  &effects) {
730  if (isa<BaseMemRefType>(getDest().getType()))
731  effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
733 }
734 
735 //===----------------------------------------------------------------------===//
736 // ToTensorOp
737 //===----------------------------------------------------------------------===//
738 
739 bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
740  return getWritable();
741 }
742 
743 OpFoldResult ToTensorOp::fold(FoldAdaptor) {
744  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
745  // Approximate alias analysis by conservatively folding only when no there
746  // is no interleaved operation.
747  if (toMemref->getBlock() == this->getOperation()->getBlock() &&
748  toMemref->getNextNode() == this->getOperation())
749  return toMemref.getTensor();
750  return {};
751 }
752 
753 namespace {
754 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
756 
757  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
758  PatternRewriter &rewriter) const override {
759  auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
760  if (!memrefToTensorOp)
761  return failure();
762 
763  rewriter.replaceOpWithNewOp<memref::DimOp>(
764  dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
765  return success();
766  }
767 };
768 } // namespace
769 
770 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
771  MLIRContext *context) {
772  results.add<DimOfToTensorFolder>(context);
773 }
774 
775 //===----------------------------------------------------------------------===//
776 // ToMemrefOp
777 //===----------------------------------------------------------------------===//
778 
779 OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
780  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
781  if (memrefToTensor.getMemref().getType() == getType())
782  return memrefToTensor.getMemref();
783  return {};
784 }
785 
786 namespace {
787 
788 /// Replace tensor.cast + to_memref by to_memref + memref.cast.
789 struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
791 
792  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
793  PatternRewriter &rewriter) const final {
794  auto tensorCastOperand =
795  toMemref.getOperand().getDefiningOp<tensor::CastOp>();
796  if (!tensorCastOperand)
797  return failure();
798  auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
799  tensorCastOperand.getOperand().getType());
800  if (!srcTensorType)
801  return failure();
802  auto memrefType = MemRefType::get(srcTensorType.getShape(),
803  srcTensorType.getElementType());
804  Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
805  tensorCastOperand.getOperand());
806  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
807  memref);
808  return success();
809  }
810 };
811 
812 /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
813 /// cast if necessary.
814 struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
816 
817  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
818  PatternRewriter &rewriter) const final {
820  options.bufferAlignment = 0;
821  return foldToMemrefToTensorPair(rewriter, toMemref, options);
822  }
823 };
824 
825 /// Fold a load on a to_memref operation into a tensor.extract on the
826 /// corresponding tensor.
827 struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
829 
830  LogicalResult matchAndRewrite(memref::LoadOp load,
831  PatternRewriter &rewriter) const override {
832  auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
833  if (!toMemref)
834  return failure();
835 
836  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
837  load.getIndices());
838  return success();
839  }
840 };
841 
842 /// Fold dim of a to_memref into the dim of the tensor.
843 struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
845 
846  LogicalResult matchAndRewrite(memref::DimOp dimOp,
847  PatternRewriter &rewriter) const override {
848  auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
849  if (!castOp)
850  return failure();
851  Value newSource = castOp.getOperand();
852  rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
853  dimOp.getIndex());
854  return success();
855  }
856 };
857 
858 } // namespace
859 
860 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
861  MLIRContext *context) {
862  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
863  ToMemrefToTensorFolding>(context);
864 }
865 
866 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
867  const BufferizationOptions &options) {
868  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
869  (void)foldToMemrefToTensorPair(rewriter, *this, options);
870  // Note: The return value of `bufferize` indicates whether there was an error
871  // or not. (And not whether the pattern matched or not.)
872  return success();
873 }
874 
875 std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
876  Value alloc) {
877  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
878  .getOperation();
879 }
880 
881 std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
882  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // DeallocOp
887 //===----------------------------------------------------------------------===//
888 
889 LogicalResult DeallocOp::inferReturnTypes(
890  MLIRContext *context, std::optional<::mlir::Location> location,
891  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
892  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
893  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
894  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
895  IntegerType::get(context, 1));
896  return success();
897 }
898 
899 LogicalResult DeallocOp::verify() {
900  if (getMemrefs().size() != getConditions().size())
901  return emitOpError(
902  "must have the same number of conditions as memrefs to deallocate");
903  if (getRetained().size() != getUpdatedConditions().size())
904  return emitOpError("must have the same number of updated conditions "
905  "(results) as retained operands");
906  return success();
907 }
908 
909 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
910  ValueRange memrefs,
911  ValueRange conditions,
912  PatternRewriter &rewriter) {
913  if (deallocOp.getMemrefs() == memrefs &&
914  deallocOp.getConditions() == conditions)
915  return failure();
916 
917  rewriter.modifyOpInPlace(deallocOp, [&]() {
918  deallocOp.getMemrefsMutable().assign(memrefs);
919  deallocOp.getConditionsMutable().assign(conditions);
920  });
921  return success();
922 }
923 
924 namespace {
925 
926 /// Remove duplicate values in the list of memrefs to be deallocated. We need to
927 /// make sure the corresponding condition value is updated accordingly since
928 /// their two conditions might not cover the same set of cases. In that case, we
929 /// have to combine them (by computing the disjunction of them).
930 /// Example:
931 /// ```mlir
932 /// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
933 /// ```
934 /// is canonicalized to
935 /// ```mlir
936 /// %0 = arith.ori %arg1, %arg2 : i1
937 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
938 /// ```
939 struct DeallocRemoveDuplicateDeallocMemrefs
940  : public OpRewritePattern<DeallocOp> {
942 
943  LogicalResult matchAndRewrite(DeallocOp deallocOp,
944  PatternRewriter &rewriter) const override {
945  // Unique memrefs to be deallocated.
946  DenseMap<Value, unsigned> memrefToCondition;
947  SmallVector<Value> newMemrefs, newConditions;
948  for (auto [i, memref, cond] :
949  llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
950  if (memrefToCondition.count(memref)) {
951  // If the dealloc conditions don't match, we need to make sure that the
952  // dealloc happens on the union of cases.
953  Value &newCond = newConditions[memrefToCondition[memref]];
954  if (newCond != cond)
955  newCond =
956  rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
957  } else {
958  memrefToCondition.insert({memref, newConditions.size()});
959  newMemrefs.push_back(memref);
960  newConditions.push_back(cond);
961  }
962  }
963 
964  // Return failure if we don't change anything such that we don't run into an
965  // infinite loop of pattern applications.
966  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
967  rewriter);
968  }
969 };
970 
971 /// Remove duplicate values in the list of retained memrefs. We need to make
972 /// sure the corresponding result condition value is replaced properly.
973 /// Example:
974 /// ```mlir
975 /// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
976 /// ```
977 /// is canonicalized to
978 /// ```mlir
979 /// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
980 /// ```
981 struct DeallocRemoveDuplicateRetainedMemrefs
982  : public OpRewritePattern<DeallocOp> {
984 
985  LogicalResult matchAndRewrite(DeallocOp deallocOp,
986  PatternRewriter &rewriter) const override {
987  // Unique retained values
989  SmallVector<Value> newRetained;
990  SmallVector<unsigned> resultReplacementIdx;
991  unsigned i = 0;
992  for (auto retained : deallocOp.getRetained()) {
993  if (seen.count(retained)) {
994  resultReplacementIdx.push_back(seen[retained]);
995  continue;
996  }
997 
998  seen[retained] = i;
999  newRetained.push_back(retained);
1000  resultReplacementIdx.push_back(i++);
1001  }
1002 
1003  // Return failure if we don't change anything such that we don't run into an
1004  // infinite loop of pattern applications.
1005  if (newRetained.size() == deallocOp.getRetained().size())
1006  return failure();
1007 
1008  // We need to create a new op because the number of results is always the
1009  // same as the number of condition operands.
1010  auto newDeallocOp =
1011  rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
1012  deallocOp.getConditions(), newRetained);
1013  SmallVector<Value> replacements(
1014  llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
1015  return newDeallocOp.getUpdatedConditions()[idx];
1016  }));
1017  rewriter.replaceOp(deallocOp, replacements);
1018  return success();
1019  }
1020 };
1021 
1022 /// Erase deallocation operations where the variadic list of memrefs to
1023 /// deallocate is empty. Example:
1024 /// ```mlir
1025 /// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1026 /// ```
1027 struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1029 
1030  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1031  PatternRewriter &rewriter) const override {
1032  if (deallocOp.getMemrefs().empty()) {
1033  Value constFalse = rewriter.create<arith::ConstantOp>(
1034  deallocOp.getLoc(), rewriter.getBoolAttr(false));
1035  rewriter.replaceOp(
1036  deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1037  constFalse));
1038  return success();
1039  }
1040  return failure();
1041  }
1042 };
1043 
1044 /// Removes memrefs from the deallocation list if their associated condition is
1045 /// always 'false'.
1046 ///
1047 /// Example:
1048 /// ```
1049 /// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1050 /// if (%arg2, %false)
1051 /// ```
1052 /// becomes
1053 /// ```
1054 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1055 /// ```
1056 struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1058 
1059  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1060  PatternRewriter &rewriter) const override {
1061  SmallVector<Value> newMemrefs, newConditions;
1062  for (auto [memref, cond] :
1063  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1064  if (!matchPattern(cond, m_Zero())) {
1065  newMemrefs.push_back(memref);
1066  newConditions.push_back(cond);
1067  }
1068  }
1069 
1070  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1071  rewriter);
1072  }
1073 };
1074 
1075 /// The `memref.extract_strided_metadata` is often inserted to get the base
1076 /// memref if the operand is not already guaranteed to be the result of a memref
1077 /// allocation operation. This canonicalization pattern removes this extraction
1078 /// operation if the operand is now produced by an allocation operation (e.g.,
1079 /// due to other canonicalizations simplifying the IR).
1080 ///
1081 /// Example:
1082 /// ```mlir
1083 /// %alloc = memref.alloc() : memref<2xi32>
1084 /// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1085 /// %alloc : memref<2xi32> -> memref<i32>, index, index, index
1086 /// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1087 /// ```
1088 /// is canonicalized to
1089 /// ```mlir
1090 /// %alloc = memref.alloc() : memref<2xi32>
1091 /// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1092 /// ```
1093 struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1095 
1096  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1097  PatternRewriter &rewriter) const override {
1098  SmallVector<Value> newMemrefs(
1099  llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1100  auto extractStridedOp =
1101  memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1102  if (!extractStridedOp)
1103  return memref;
1104  Value allocMemref = extractStridedOp.getOperand();
1105  auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1106  if (!allocOp)
1107  return memref;
1108  if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1109  return allocMemref;
1110  return memref;
1111  }));
1112 
1113  return updateDeallocIfChanged(deallocOp, newMemrefs,
1114  deallocOp.getConditions(), rewriter);
1115  }
1116 };
1117 
1118 /// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1119 /// other user of the allocated value and the allocating operation can be safely
1120 /// removed. If the same value is present multiple times, this pattern relies on
1121 /// other canonicalization patterns to remove the duplicate first.
1122 ///
1123 /// Example:
1124 /// ```mlir
1125 /// %alloc = memref.alloc() : memref<2xi32>
1126 /// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
1127 /// ```
1128 /// is canonicalized to
1129 /// ```mlir
1130 /// bufferization.dealloc (%arg0 : ...) if (%true)
1131 /// ```
1132 struct RemoveAllocDeallocPairWhenNoOtherUsers
1133  : public OpRewritePattern<DeallocOp> {
1135 
1136  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1137  PatternRewriter &rewriter) const override {
1138  SmallVector<Value> newMemrefs, newConditions;
1139  SmallVector<Operation *> toDelete;
1140  for (auto [memref, cond] :
1141  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1142  if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1143  // Check that it is indeed an allocate effect, that the op has no other
1144  // side effects (which would not allow us to remove the op), and that
1145  // there are no other users.
1146  if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1147  hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
1148  memref.hasOneUse()) {
1149  toDelete.push_back(allocOp);
1150  continue;
1151  }
1152  }
1153 
1154  newMemrefs.push_back(memref);
1155  newConditions.push_back(cond);
1156  }
1157 
1158  if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1159  rewriter)))
1160  return failure();
1161 
1162  for (Operation *op : toDelete)
1163  rewriter.eraseOp(op);
1164 
1165  return success();
1166  }
1167 };
1168 
1169 } // anonymous namespace
1170 
1171 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1172  MLIRContext *context) {
1174 }
1175 
1177  RewritePatternSet &patterns, MLIRContext *context) {
1178  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1179  DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1180  EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1181  RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1182 }
1183 
1184 //===----------------------------------------------------------------------===//
1185 // TableGen'd op method definitions
1186 //===----------------------------------------------------------------------===//
1187 
1188 #define GET_OP_CLASSES
1189 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base class for generic analysis states.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseCustomTypeWithFallback(Type &result, function_ref< ParseResult(Type &result)> parseType)=0
Parse a custom type with the provided callback, unless the next token is #, in which case the generic...
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
void printStrippedAttrOrType(AttrOrType attrOrType)
Print the provided attribute in the context of an operation custom printer/parser: this will invoke d...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
IndexType getIndexType()
Definition: Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:96
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref, const BufferizationOptions &options)
Try to fold to_memref(to_tensor(x)).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the giv...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
The following effect indicates that the operation allocates from some resource.
The following effect indicates that the operation frees some resource that has been allocated.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Options for BufferizableOpInterface-based bufferization.