//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>
using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = memref::CastOp::create(b, value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = memref::DimOp::create(b, loc, value, i);
    dynamicOperands.push_back(size);
  }

  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
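// Illustrative sketch (not from the original source; SSA names and types are
// made up): for a source of type `memref<?xf32, strided<[?], offset: ?>>` and
// a destination type `memref<?xf32>` (static offset 0, static stride 1), the
// offset/stride go from dynamic to static, so a plain cast is not guaranteed
// to succeed and the helper emits roughly:
// ```mlir
// %c0 = arith.constant 0 : index
// %d = memref.dim %src, %c0 : memref<?xf32, strided<[?], offset: ?>>
// %alloc = memref.alloc(%d) : memref<?xf32>
// memref.copy %src, %alloc : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32>
// ```
// (the exact alloc/copy ops depend on `options`); a cast-compatible pair with
// no dynamic-to-static offset/stride transition folds to a single memref.cast.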

/// Try to fold to_buffer(to_tensor(x)). If x's type and the result type of the
/// to_buffer op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToBufferToTensorPair(
    RewriterBase &rewriter, ToBufferOp toBuffer,
    const BufferizationOptions &options) {
  auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
  if (!bufferToTensor)
    return failure();

  Type srcType = bufferToTensor.getBuffer().getType();
  Type destType = toBuffer.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toBuffer, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
                                              bufferToTensor.getBuffer());
  return success();
}
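// Illustrative sketch (assembly is schematic, not from the original source):
// ```mlir
// %t = bufferization.to_tensor %m : memref<4xf32> to tensor<4xf32>
// %b = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
// ```
// folds so that all uses of %b use %m directly; if the buffer types differ but
// are cast compatible, a memref.cast (or, for ranked types, a reallocation +
// copy via castOrReallocMemRefValue) is inserted instead.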

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer =
        getBuffer(rewriter, getCopy(), options, state);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options, state);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}
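// Illustrative sketch (not from the original source): without a `copy`
// operand,
// ```mlir
// %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
// ```
// bufferizes to an allocation such as `memref.alloc(%sz) : memref<?xf32>` (the
// exact allocation op is determined by `options.createAlloc`); with a `copy`
// operand, a memcpy from the bufferized copy source is emitted as well.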

bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType = asMemRefType(bufferization::getBufferType(
        getCopy(), options, state, invocationStack));
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return cast<BufferLikeType>(
      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation but whose size was provided by a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
                                       newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

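/// Fold tensor.dim of a dynamic dimension of an alloc_tensor result to the
/// corresponding dynamic size operand. Illustrative sketch (not from the
/// original source; SSA names are made up):
/// ```mlir
/// %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
/// %d = tensor.dim %0, %c0 : tensor<?xf32>
/// ```
/// folds %d to %sz (or to a tensor.dim of the `copy` operand, if present).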
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return tensor::DimOp::create(b, getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
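/// Illustrative sketch (not from the original source; SSA names and types are
/// made up):
/// ```mlir
/// %1 = bufferization.clone %0 : memref<2xf32> to memref<2xf32>
/// "use"(%1) : (memref<2xf32>) -> ()
/// memref.dealloc %1 : memref<2xf32>
/// ```
/// is rewritten so that the uses refer to %0 (via a memref.cast if the types
/// differ) and the now-redundant dealloc in the same block is erased, provided
/// the pattern's safety checks on intervening deallocations pass.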
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aims to find the dealloc op for the canonical source,
    // which otherwise could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp())) {
      if (canonicalSource != iface.getViewDest()) {
        break;
      }
      canonicalSource = iface.getViewSource();
    }

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for a deallocation op.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
                                      cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options,
                                         BufferizationState &state) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
  if (failed(buffer))
    return failure();
  memref::DeallocOp::create(rewriter, getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize in-place
  // (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options,
                                      BufferizationState &state) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, getDest(), options, state);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return ToTensorOp::create(
      builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()),
      getDest(),
      /*restrict=*/true, getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension size. Assume that they match at
        // runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}
void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
        toBuffer->getNextNode() == this->getOperation())
      return toBuffer.getTensor();
  return {};
}
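// Illustrative sketch (assembly is schematic, not from the original source):
// when the two ops are directly adjacent in the same block,
// ```mlir
// %b = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
// %t2 = bufferization.to_tensor %b : memref<4xf32> to tensor<4xf32>
// ```
// %t2 folds to %t; the adjacency requirement conservatively rules out
// interleaved writes to the buffer.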

namespace {
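/// Replace tensor.dim of a to_tensor result with memref.dim of the underlying
/// buffer. Illustrative sketch (not from the original source; SSA names are
/// made up):
/// ```mlir
/// %t = bufferization.to_tensor %m : memref<?xf32> to tensor<?xf32>
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```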
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToBufferOp
//===----------------------------------------------------------------------===//

OpFoldResult ToBufferOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getBuffer().getType() == getType())
      return memrefToTensor.getBuffer();
  return {};
}

namespace {
/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
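/// Illustrative sketch (not from the original source; SSA names and types are
/// made up):
/// ```mlir
/// %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// %1 = bufferization.to_buffer %0 : tensor<?xf32> to memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// %0 = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
/// %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```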
struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto currentOutputMemRefType =
        dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
    if (!currentOutputMemRefType)
      return failure();

    auto memrefType = currentOutputMemRefType.cloneWith(
        srcTensorType.getShape(), srcTensorType.getElementType());
    Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
                                      tensorCastOperand.getOperand(),
                                      toBuffer.getReadOnly());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
/// cast if necessary.
struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToBufferToTensorPair(rewriter, toBuffer, options);
  }
};

/// Fold a load on a to_buffer operation into a tensor.extract on the
/// corresponding tensor.
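/// Illustrative sketch (not from the original source; SSA names are made up):
/// ```mlir
/// %m = bufferization.to_buffer %t : tensor<?xf32> to memref<?xf32>
/// %v = memref.load %m[%i] : memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// %v = tensor.extract %t[%i] : tensor<?xf32>
/// ```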
struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
    if (!toBuffer)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_buffer into the dim of the tensor.
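/// Illustrative sketch (not from the original source; SSA names are made up):
/// ```mlir
/// %m = bufferization.to_buffer %t : tensor<?xf32> to memref<?xf32>
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```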
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(context);
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToBufferToTensorPair(rewriter, *this, options);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return CloneOp::create(builder, alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}
LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}
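// Illustrative sketch of a well-formed dealloc (not from the original source;
// SSA names are made up): the number of deallocated memrefs must match the
// number of conditions, and the number of results matches the number of
// retained memrefs:
// ```mlir
// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
//            if (%cond0, %cond1)
//            retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
// ```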

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. We need
/// to make sure the corresponding condition value is updated accordingly since
/// the conditions of the duplicates might not cover the same set of cases. In
/// that case, we have to combine them by computing their disjunction.
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything such that we don't run into
    // an infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything such that we don't run into
    // an infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of condition operands.
    auto newDeallocOp =
        DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
                          deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(),
                                                   rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition is
/// always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                           if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// A `memref.extract_strided_metadata` op is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///     %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"