MLIR  20.0.0git
BufferizationOps.cpp
Go to the documentation of this file.
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/Matchers.h"
17 #include <optional>
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 
22 //===----------------------------------------------------------------------===//
23 // Helper functions
24 //===----------------------------------------------------------------------===//
25 
/// Produces a value of type `destType` holding the contents of the ranked
/// memref `value`. This is either a memref.cast, emitted only when the cast is
/// guaranteed to succeed at runtime, or a newly allocated buffer of
/// `destType` (via `options.createAlloc`) into which `value` is copied (via
/// `options.createMemCpy`). Fails if element type, memory space or rank
/// differ, or if the allocation or the copy cannot be created.
27  OpBuilder &b, Value value, MemRefType destType,
29  auto srcType = llvm::cast<MemRefType>(value.getType());
30 
31  // Element type, rank and memory space must match.
32  if (srcType.getElementType() != destType.getElementType())
33  return failure();
34  if (srcType.getMemorySpace() != destType.getMemorySpace())
35  return failure();
36  if (srcType.getRank() != destType.getRank())
37  return failure();
38 
39  // In case the affine maps are different, we may need to use a copy if we go
40  // from dynamic to static offset or stride (the canonicalization cannot know
41  // at this point that it is really cast compatible).
// Returns false if either type has no strided form, or if the offset or any
// stride goes from dynamic (source) to static (target).
42  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
43  int64_t sourceOffset, targetOffset;
44  SmallVector<int64_t, 4> sourceStrides, targetStrides;
45  if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
46  failed(getStridesAndOffset(target, targetStrides, targetOffset)))
47  return false;
48  auto dynamicToStatic = [](int64_t a, int64_t b) {
49  return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
50  };
51  if (dynamicToStatic(sourceOffset, targetOffset))
52  return false;
53  for (auto it : zip(sourceStrides, targetStrides))
54  if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
55  return false;
56  return true;
57  };
58 
59  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
60  // ensure that we only generate casts that always succeed at runtime, we check
61  // a few extra conditions in `isGuaranteedCastCompatible`.
62  if (memref::CastOp::areCastCompatible(srcType, destType) &&
63  isGuaranteedCastCompatible(srcType, destType)) {
64  Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
65  return casted;
66  }
67 
// Fallback: allocate a fresh buffer of `destType` and copy `value` into it.
// Only the dynamic dimensions of `destType` need size operands; their values
// are read from the source buffer with memref.dim.
68  auto loc = value.getLoc();
69  SmallVector<Value, 4> dynamicOperands;
70  for (int i = 0; i < destType.getRank(); ++i) {
71  if (destType.getShape()[i] != ShapedType::kDynamic)
72  continue;
73  Value size = b.create<memref::DimOp>(loc, value, i);
74  dynamicOperands.push_back(size);
75  }
76 
77  FailureOr<Value> copy =
78  options.createAlloc(b, loc, destType, dynamicOperands);
79  if (failed(copy))
80  return failure();
81  if (failed(options.createMemCpy(b, loc, value, *copy)))
82  return failure();
83  return copy;
84 }
85 
86 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
87 /// to_memref op are different, a memref.cast is needed.
///
/// Ranked -> ranked type changes go through `castOrReallocMemRefValue`, which
/// may allocate and copy. Fails if the operand is not produced by a to_tensor
/// op, if `castOrReallocMemRefValue` fails, or for the (not yet implemented)
/// unranked -> ranked case.
89  RewriterBase &rewriter, ToMemrefOp toMemref,
91  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
92  if (!memrefToTensor)
93  return failure();
94 
95  Type srcType = memrefToTensor.getMemref().getType();
96  Type destType = toMemref.getType();
97 
98  // Directly rewrite if the type did not change.
99  if (srcType == destType) {
100  rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
101  return success();
102  }
103 
104  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
105  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
106  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
107 
108  // Ranked memref -> Ranked memref cast.
109  if (rankedSrcType && rankedDestType) {
110  FailureOr<Value> replacement = castOrReallocMemRefValue(
111  rewriter, memrefToTensor.getMemref(), rankedDestType, options);
112  if (failed(replacement))
113  return failure();
114 
115  rewriter.replaceOp(toMemref, *replacement);
116  return success();
117  }
118 
119  // Unranked memref -> Ranked memref cast: May require a copy.
120  // TODO: Not implemented at the moment.
121  if (unrankedSrcType && rankedDestType)
122  return failure();
123 
124  // Unranked memref -> unranked memref cast
125  // Ranked memref -> unranked memref cast: No copy needed.
126  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
127  "expected that types are cast compatible");
128  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
129  memrefToTensor.getMemref());
130  return success();
131 }
132 
/// Appends to `dynamicDims` one size value for every dynamic dimension of
/// `shapedValue`, in dimension order. Ranked memrefs use memref.dim and ranked
/// tensors use tensor.dim. Existing entries of `dynamicDims` are kept (the
/// vector is only appended to).
134  OpBuilder &b, Location loc, Value shapedValue,
135  SmallVector<Value> &dynamicDims) {
136  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
137  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
138  if (shapedType.isDynamicDim(i)) {
139  if (llvm::isa<MemRefType>(shapedType)) {
140  dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
141  } else {
// Any non-memref shaped value is expected to be a ranked tensor here.
142  assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
143  dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
144  }
145  }
146  }
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // AllocTensorOp
151 //===----------------------------------------------------------------------===//
152 
153 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
154  const BufferizationOptions &options) {
155  OpBuilder::InsertionGuard g(rewriter);
156  Location loc = getLoc();
157 
158  // Nothing to do for dead AllocTensorOps.
159  if (getOperation()->getUses().empty()) {
160  rewriter.eraseOp(getOperation());
161  return success();
162  }
163 
164  // Get "copy" buffer.
165  Value copyBuffer;
166  if (getCopy()) {
167  FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
168  if (failed(maybeCopyBuffer))
169  return failure();
170  copyBuffer = *maybeCopyBuffer;
171  }
172 
173  // Create memory allocation.
174  auto allocType = bufferization::getBufferType(getResult(), options);
175  if (failed(allocType))
176  return failure();
177  SmallVector<Value> dynamicDims = getDynamicSizes();
178  if (getCopy()) {
179  assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
180  populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
181  }
182  FailureOr<Value> alloc = options.createAlloc(
183  rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
184  if (failed(alloc))
185  return failure();
186 
187  // Create memory copy (if any).
188  if (getCopy()) {
189  if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
190  return failure();
191  }
192 
193  // Replace op.
194  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
195 
196  return success();
197 }
198 
199 bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
200  const AnalysisState &state) {
201  // AllocTensorOps do not write unless they have a `copy` value.
202  return static_cast<bool>(getCopy());
203 }
204 
205 bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
206  const AnalysisState &state) {
207  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
208  "expected copy operand");
209  return true;
210 }
211 
212 bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
213  const AnalysisState &state) {
214  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
215  "expected copy operand");
216  return false;
217 }
218 
219 AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
220  const AnalysisState &state) {
221  // This is a new allocation. It does not alias with any other buffer.
222  return {};
223 }
224 
/// Computes the buffer type of the `alloc_tensor` result: a memref type built
/// by `getMemRefTypeWithStaticIdentityLayout` from the tensor type. The memory
/// space is taken, in order of precedence, from the explicit `memory_space`
/// attribute, from the buffer type of the `copy` operand, or from
/// `options.defaultMemorySpaceFn`. Emits "could not infer memory space" and
/// fails if none of these applies.
225 FailureOr<BaseMemRefType>
227  SmallVector<Value> &invocationStack) {
228  assert(value == getResult() && "invalid value");
229 
230  // Compute memory space of this allocation.
231  Attribute memorySpace;
232  if (getMemorySpace().has_value()) {
233  memorySpace = *getMemorySpace();
234  } else if (getCopy()) {
// Inherit the memory space of the buffer being copied from.
235  auto copyBufferType =
236  bufferization::getBufferType(getCopy(), options, invocationStack);
237  if (failed(copyBufferType))
238  return failure();
239  memorySpace = copyBufferType->getMemorySpace();
240  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
241  memorySpace = *ms;
242  } else {
243  return getOperation()->emitError("could not infer memory space");
244  }
245 
246  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
247 }
248 
249 LogicalResult AllocTensorOp::verify() {
250  if (getCopy() && !getDynamicSizes().empty())
251  return emitError("dynamic sizes not needed when copying a tensor");
252  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
253  return emitError("expected ")
254  << getType().getNumDynamicDims() << " dynamic sizes";
255  if (getCopy() && getCopy().getType() != getType())
256  return emitError("expected that `copy` and return type match");
257  return success();
258 }
259 
260 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
261  RankedTensorType type, ValueRange dynamicSizes) {
262  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
263  /*size_hint=*/Value(),
264  /*memory_space=*/IntegerAttr());
265 }
266 
267 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
268  RankedTensorType type, ValueRange dynamicSizes,
269  Value copy) {
270  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
271  /*memory_space=*/IntegerAttr());
272 }
273 
274 void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
275  TensorType type, ValueRange dynamicSizes, Value copy,
276  IntegerAttr memorySpace) {
277  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
278  memorySpace);
279 }
280 
281 namespace {
282 /// Change the type of the result of a `bufferization.alloc_tensor` by making
283 /// the result type statically sized along dimensions that in the original
284 /// operation were defined as dynamic, but whose size was defined using a
285 /// `constant` op. For example:
286 ///
287 /// %c5 = arith.constant 5: index
288 /// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
289 ///
290 /// to
291 ///
292 /// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
293 struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
295 
296  LogicalResult matchAndRewrite(AllocTensorOp op,
297  PatternRewriter &rewriter) const override {
298  if (op.getCopy())
299  return failure();
300  SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
301  SmallVector<Value> newDynamicSizes;
302  unsigned int dynValCounter = 0;
303  for (int64_t i = 0; i < op.getType().getRank(); ++i) {
304  if (!op.isDynamicDim(i))
305  continue;
306  Value value = op.getDynamicSizes()[dynValCounter++];
307  APInt intVal;
308  if (matchPattern(value, m_ConstantInt(&intVal))) {
309  int64_t dim = intVal.getSExtValue();
310  if (dim >= 0)
311  newShape[i] = intVal.getSExtValue();
312  else
313  newDynamicSizes.push_back(value);
314  } else {
315  newDynamicSizes.push_back(value);
316  }
317  }
318  RankedTensorType newType = RankedTensorType::get(
319  newShape, op.getType().getElementType(), op.getType().getEncoding());
320  if (newType == op.getType())
321  return failure();
322  auto newOp = rewriter.create<AllocTensorOp>(
323  op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
324  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
325  return success();
326  }
327 };
328 
329 struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
331 
332  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
333  PatternRewriter &rewriter) const override {
334  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
335  auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
336  if (!allocTensorOp || !maybeConstantIndex)
337  return failure();
338  if (*maybeConstantIndex < 0 ||
339  *maybeConstantIndex >= allocTensorOp.getType().getRank())
340  return failure();
341  if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
342  return failure();
343  rewriter.replaceOp(
344  dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
345  return success();
346  }
347 };
348 } // namespace
349 
350 void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
351  MLIRContext *ctx) {
352  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
353 }
354 
/// Reifies the shape of the single result: each dynamic dimension maps to the
/// value from `getDynamicSize` (which may create a tensor.dim on `copy`), and
/// each static dimension maps to an index attribute.
356  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
357  auto shapes = llvm::to_vector<4>(
358  llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
359  [&](int64_t dim) -> OpFoldResult {
360  if (isDynamicDim(dim))
361  return getDynamicSize(builder, dim);
362  return builder.getIndexAttr(getStaticSize(dim));
363  }));
364  reifiedReturnShapes.emplace_back(std::move(shapes));
365  return success();
366 }
367 
/// Parses the custom form
///   `(` dynamic-sizes `)` [`copy` `(` %tensor `)`] [`size_hint` `=` %hint]
///   attr-dict `:` tensor-type
/// and records the operand segment sizes (dynamic sizes, copy, size_hint).
368 ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
369  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
370  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
371  parser.parseRParen())
372  return failure();
// Optional `copy(%t)` clause.
373  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
374  OpAsmParser::UnresolvedOperand copyOperand;
375  if (copyKeyword.succeeded())
376  if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
377  parser.parseRParen())
378  return failure();
// Optional `size_hint = %h` clause.
379  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
380  OpAsmParser::UnresolvedOperand sizeHintOperand;
381  if (sizeHintKeyword.succeeded())
382  if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
383  return failure();
384  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
385  return failure();
386 
387  TensorType type;
388  if (parser.parseCustomTypeWithFallback(type))
389  return failure();
390  result.addTypes(type);
391 
// Resolve operands in segment order: dynamic sizes (index), copy (result
// tensor type), size_hint (index).
392  Type indexType = parser.getBuilder().getIndexType();
393  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
394  return failure();
395  if (copyKeyword.succeeded())
396  if (parser.resolveOperand(copyOperand, type, result.operands))
397  return failure();
398  if (sizeHintKeyword.succeeded())
399  if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
400  return failure();
401  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
403  {static_cast<int32_t>(dynamicSizesOperands.size()),
404  static_cast<int32_t>(copyKeyword.succeeded()),
405  static_cast<int32_t>(sizeHintKeyword.succeeded())}));
406  return success();
407 }
408 
/// Prints the custom form accepted by `AllocTensorOp::parse`. The operand
/// segment size attribute is elided because it is implied by the syntax.
410  p << "(" << getDynamicSizes() << ")";
411  if (getCopy())
412  p << " copy(" << getCopy() << ")";
413  if (getSizeHint())
414  p << " size_hint=" << getSizeHint();
415  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
416  AllocTensorOp::getOperandSegmentSizeAttr()});
417  p << " : ";
// Tensor types are printed without the dialect prefix; anything else is
// printed in full.
418  auto type = getResult().getType();
419  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
420  p.printStrippedAttrOrType(validType);
421  else
422  p << type;
423 }
424 
425 Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
426  assert(isDynamicDim(idx) && "expected dynamic dim");
427  if (getCopy())
428  return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
429  return getOperand(getIndexOfDynamicSize(idx));
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // CloneOp
434 //===----------------------------------------------------------------------===//
435 
436 OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
437  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
438 }
439 
440 namespace {
441 
442 /// Merge the clone and its source (by converting the clone to a cast) when
443 /// possible.
444 struct SimplifyClones : public OpRewritePattern<CloneOp> {
446 
// Replaces a clone by its (possibly memref.cast-ed) source and erases one
// dealloc that became redundant. This only applies when exactly one of the
// clone/source deallocations lives in the clone's block, and no op between
// the clone and that dealloc has a Free effect. Unused clones are erased.
447  LogicalResult matchAndRewrite(CloneOp cloneOp,
448  PatternRewriter &rewriter) const override {
449  if (cloneOp.use_empty()) {
450  rewriter.eraseOp(cloneOp);
451  return success();
452  }
453 
// The source must be usable in place of the clone, possibly after a cast.
454  Value source = cloneOp.getInput();
455  if (source.getType() != cloneOp.getType() &&
456  !memref::CastOp::areCastCompatible({source.getType()},
457  {cloneOp.getType()}))
458  return failure();
459 
460  // Aims to find the dealloc op for the canonical source
461  // which otherwise could prevent removal of unnecessary allocs.
462  Value canonicalSource = source;
463  while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
464  canonicalSource.getDefiningOp()))
465  canonicalSource = iface.getViewSource();
466 
467  std::optional<Operation *> maybeCloneDeallocOp =
468  memref::findDealloc(cloneOp.getOutput());
469  // Skip if either of them has > 1 deallocate operations.
470  if (!maybeCloneDeallocOp.has_value())
471  return failure();
472  std::optional<Operation *> maybeSourceDeallocOp =
473  memref::findDealloc(canonicalSource);
474  if (!maybeSourceDeallocOp.has_value())
475  return failure();
// Either pointer may still be null (no dealloc); the checks below guard it.
476  Operation *cloneDeallocOp = *maybeCloneDeallocOp;
477  Operation *sourceDeallocOp = *maybeSourceDeallocOp;
478 
479  // If both are deallocated in the same block, their in-block lifetimes
480  // might not fully overlap, so we cannot decide which one to drop.
481  if (cloneDeallocOp && sourceDeallocOp &&
482  cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
483  return failure();
484 
// Prefer the clone's dealloc; otherwise use the source's dealloc, but only
// if it is in the clone's block.
485  Block *currentBlock = cloneOp->getBlock();
486  Operation *redundantDealloc = nullptr;
487  if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
488  redundantDealloc = cloneDeallocOp;
489  } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
490  redundantDealloc = sourceDeallocOp;
491  }
492 
493  if (!redundantDealloc)
494  return failure();
495 
496  // Safety check that there are no other deallocations in between
497  // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
498  // of source before the uses of the clone. With alias information, we could
499  // restrict this to only fail if the dealloc's operand is an alias
500  // of the source.
501  for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
502  pos = pos->getNextNode()) {
503  // Bail if we run out of operations while looking for a deallocation op.
504  if (!pos)
505  return failure();
506  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
507  if (!effectInterface)
508  continue;
509  if (effectInterface.hasEffect<MemoryEffects::Free>())
510  return failure();
511  }
512 
513  if (source.getType() != cloneOp.getType())
514  source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
515  cloneOp.getType(), source);
516  rewriter.replaceOp(cloneOp, source);
517  rewriter.eraseOp(redundantDealloc);
518  return success();
519  }
520 };
521 
522 } // namespace
523 
524 void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
525  MLIRContext *context) {
526  results.add<SimplifyClones>(context);
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // DeallocTensorOp
531 //===----------------------------------------------------------------------===//
532 
533 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
534  const BufferizationOptions &options) {
535  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
536  if (failed(buffer))
537  return failure();
538  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
539  rewriter.eraseOp(getOperation());
540  return success();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // MaterializeInDestinationOp
545 //===----------------------------------------------------------------------===//
546 
547 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
548  OpOperand &opOperand, const AnalysisState &state) {
549  return opOperand == getSourceMutable();
550 }
551 
552 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
553  OpOperand &opOperand, const AnalysisState &state) {
554  if (opOperand == getDestMutable()) {
555  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
556  return true;
557  }
558  return false;
559 }
560 
561 bool MaterializeInDestinationOp::mustBufferizeInPlace(
562  OpOperand &opOperand, const AnalysisState &state) {
563  // The source is only read and not written, so it always bufferizes in-place
564  // by default. The destination is written and is forced to bufferize in-place
565  // (if it is a tensor).
566  return true;
567 }
568 
/// A tensor `dest` operand is equivalent to the op's single result. The
/// `source` operand aliases nothing.
570 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
571  const AnalysisState &state) {
572  if (opOperand == getDestMutable()) {
573  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
574  return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
575  }
576  return {};
577 }
578 
579 LogicalResult
580 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
581  const BufferizationOptions &options) {
582  bool tensorDest = isa<TensorType>(getDest().getType());
583  Value buffer;
584  if (tensorDest) {
585  FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
586  if (failed(maybeBuffer))
587  return failure();
588  buffer = *maybeBuffer;
589  } else {
590  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
591  buffer = getDest();
592  }
593  auto srcBuffer = getBuffer(rewriter, getSource(), options);
594  if (failed(srcBuffer))
595  return failure();
596  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
597  return failure();
598  replaceOpWithBufferizedValues(rewriter, getOperation(),
599  tensorDest ? ValueRange(buffer) : ValueRange());
600  return success();
601 }
602 
603 bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
604  const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
605  // As elements are copied from the "source" buffer to the "dest" buffer,
606  // already copied elements are not read a second time.
607  return true;
608 }
609 
/// When the op has a (tensor) result, its shape is the shape of `dest`, as
/// returned by `tensor::getMixedSizes`. With a memref destination there is no
/// result and nothing is reified.
611  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
612  if (getOperation()->getNumResults() == 1) {
613  assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
// The pre-sized entry is overwritten immediately below.
614  reifiedReturnShapes.resize(1,
615  SmallVector<OpFoldResult>(getType().getRank()));
616  reifiedReturnShapes[0] =
617  tensor::getMixedSizes(builder, getLoc(), getDest());
618  }
619  return success();
620 }
621 
622 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
623  Location loc) {
624  if (isa<TensorType>(getDest().getType())) {
625  // The subset is the entire destination tensor.
626  return getDest();
627  }
628 
629  // The "restrict" attribute is transferred from this op to the newly created
630  // to_tensor op. If this op does not the "restrict" attribute, the subset
631  // extraction cannot be built because there is no guarantee that there is no
632  // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
633  if (!getRestrict())
634  return {};
635 
636  // Build a bufferization.to_tensor op.
637  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
638  assert(getRestrict() &&
639  "expected that ops with memrefs dest have 'restrict'");
640  setRestrict(false);
641  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
642  getWritable());
643 }
644 
645 bool MaterializeInDestinationOp::isEquivalentSubset(
646  Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
647  return equivalenceFn(getDest(), candidate);
648 }
649 
/// Building the subset extraction only requires the destination value.
651 MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
652  return {getDest()};
653 }
654 
655 OpOperand &MaterializeInDestinationOp::getSourceOperand() {
656  return getOperation()->getOpOperand(0) /*source*/;
657 }
658 
659 bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
660  SubsetOpInterface subsetOp,
661  function_ref<bool(Value, Value)> equivalenceFn) {
662  return false;
663 }
664 
665 bool MaterializeInDestinationOp::operatesOnDisjointSubset(
666  SubsetOpInterface subsetOp,
667  function_ref<bool(Value, Value)> equivalenceFn) {
668  return false;
669 }
670 
671 LogicalResult MaterializeInDestinationOp::verify() {
672  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
673  return emitOpError("'dest' must be a tensor or a memref");
674  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
675  if (getOperation()->getNumResults() != 1)
676  return emitOpError("tensor 'dest' implies exactly one tensor result");
677  if (destType != getResult().getType())
678  return emitOpError("result and 'dest' types must match");
679  }
680  if (isa<BaseMemRefType>(getDest().getType()) &&
681  getOperation()->getNumResults() != 0)
682  return emitOpError("memref 'dest' implies zero results");
683  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
684  return emitOpError("'restrict' is valid only for memref destinations");
685  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
686  return emitOpError("'writable' must be specified if and only if the "
687  "destination is of memref type");
688  TensorType srcType = getSource().getType();
689  ShapedType destType = cast<ShapedType>(getDest().getType());
690  if (srcType.hasRank() != destType.hasRank())
691  return emitOpError("source/destination shapes are incompatible");
692  if (srcType.hasRank()) {
693  if (srcType.getRank() != destType.getRank())
694  return emitOpError("rank mismatch between source and destination shape");
695  for (auto [src, dest] :
696  llvm::zip(srcType.getShape(), destType.getShape())) {
697  if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
698  // Cannot verify dynamic dimension size. Assume that that they match at
699  // runtime.
700  continue;
701  }
702  if (src != dest)
703  return emitOpError("source/destination shapes are incompatible");
704  }
705  }
706  return success();
707 }
708 
709 void MaterializeInDestinationOp::build(OpBuilder &builder,
710  OperationState &state, Value source,
711  Value dest) {
712  auto destTensorType = dyn_cast<TensorType>(dest.getType());
713  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
714  source, dest);
715 }
716 
717 bool MaterializeInDestinationOp::isWritable(Value value,
718  const AnalysisState &state) {
719  return isa<TensorType>(getDest().getType()) ? true : getWritable();
720 }
721 
722 MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
723  return getDestMutable();
724 }
725 
/// Reports memory effects. Only a memref destination gets an explicit Write
/// effect on the `dest` operand; tensor destinations add nothing here.
/// NOTE(review): the remaining emplace_back argument(s) are on a line not
/// visible in this view.
726 void MaterializeInDestinationOp::getEffects(
728  &effects) {
729  if (isa<BaseMemRefType>(getDest().getType()))
730  effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
732 }
733 
734 //===----------------------------------------------------------------------===//
735 // ToTensorOp
736 //===----------------------------------------------------------------------===//
737 
738 bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
739  return getWritable();
740 }
741 
742 OpFoldResult ToTensorOp::fold(FoldAdaptor) {
743  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
744  // Approximate alias analysis by conservatively folding only when no there
745  // is no interleaved operation.
746  if (toMemref->getBlock() == this->getOperation()->getBlock() &&
747  toMemref->getNextNode() == this->getOperation())
748  return toMemref.getTensor();
749  return {};
750 }
751 
752 namespace {
753 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
755 
756  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
757  PatternRewriter &rewriter) const override {
758  auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
759  if (!memrefToTensorOp)
760  return failure();
761 
762  rewriter.replaceOpWithNewOp<memref::DimOp>(
763  dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
764  return success();
765  }
766 };
767 } // namespace
768 
769 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
770  MLIRContext *context) {
771  results.add<DimOfToTensorFolder>(context);
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // ToMemrefOp
776 //===----------------------------------------------------------------------===//
777 
778 OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
779  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
780  if (memrefToTensor.getMemref().getType() == getType())
781  return memrefToTensor.getMemref();
782  return {};
783 }
784 
785 namespace {
786 
787 /// Replace tensor.cast + to_memref by to_memref + memref.cast.
788 struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
790 
791  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
792  PatternRewriter &rewriter) const final {
793  auto tensorCastOperand =
794  toMemref.getOperand().getDefiningOp<tensor::CastOp>();
795  if (!tensorCastOperand)
796  return failure();
797  auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
798  tensorCastOperand.getOperand().getType());
799  if (!srcTensorType)
800  return failure();
801  auto memrefType = MemRefType::get(srcTensorType.getShape(),
802  srcTensorType.getElementType());
803  Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
804  tensorCastOperand.getOperand());
805  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
806  memref);
807  return success();
808  }
809 };
810 
811 /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
812 /// cast if necessary.
813 struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
815 
// Folds to_memref(to_tensor(x)) via `foldToMemrefToTensorPair`, which may
// insert a cast or an alloc + copy.
816  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
817  PatternRewriter &rewriter) const final {
// NOTE(review): `options` is declared on a line not visible here, presumably
// a local BufferizationOptions. The alignment is set to 0, presumably so that
// any buffer allocated by the fold requests no specific alignment; confirm.
819  options.bufferAlignment = 0;
820  return foldToMemrefToTensorPair(rewriter, toMemref, options);
821  }
822 };
823 
824 /// Fold a load on a to_memref operation into an tensor.extract on the
825 /// corresponding tensor.
826 struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
828 
829  LogicalResult matchAndRewrite(memref::LoadOp load,
830  PatternRewriter &rewriter) const override {
831  auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
832  if (!toMemref)
833  return failure();
834 
835  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
836  load.getIndices());
837  return success();
838  }
839 };
840 
841 /// Fold dim of a to_memref into the dim of the tensor.
842 struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
844 
845  LogicalResult matchAndRewrite(memref::DimOp dimOp,
846  PatternRewriter &rewriter) const override {
847  auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
848  if (!castOp)
849  return failure();
850  Value newSource = castOp.getOperand();
851  rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
852  dimOp.getIndex());
853  return success();
854  }
855 };
856 
857 } // namespace
858 
859 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
860  MLIRContext *context) {
861  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
862  ToMemrefToTensorFolding>(context);
863 }
864 
865 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
866  const BufferizationOptions &options) {
867  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
868  (void)foldToMemrefToTensorPair(rewriter, *this, options);
869  // Note: The return value of `bufferize` indicates whether there was an error
870  // or not. (And not whether the pattern matched or not.)
871  return success();
872 }
873 
874 std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
875  Value alloc) {
876  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
877  .getOperation();
878 }
879 
880 std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
881  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // DeallocOp
886 //===----------------------------------------------------------------------===//
887 
888 LogicalResult DeallocOp::inferReturnTypes(
889  MLIRContext *context, std::optional<::mlir::Location> location,
890  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
891  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
892  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
893  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
894  IntegerType::get(context, 1));
895  return success();
896 }
897 
898 LogicalResult DeallocOp::verify() {
899  if (getMemrefs().size() != getConditions().size())
900  return emitOpError(
901  "must have the same number of conditions as memrefs to deallocate");
902  if (getRetained().size() != getUpdatedConditions().size())
903  return emitOpError("must have the same number of updated conditions "
904  "(results) as retained operands");
905  return success();
906 }
907 
908 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
909  ValueRange memrefs,
910  ValueRange conditions,
911  PatternRewriter &rewriter) {
912  if (deallocOp.getMemrefs() == memrefs &&
913  deallocOp.getConditions() == conditions)
914  return failure();
915 
916  rewriter.modifyOpInPlace(deallocOp, [&]() {
917  deallocOp.getMemrefsMutable().assign(memrefs);
918  deallocOp.getConditionsMutable().assign(conditions);
919  });
920  return success();
921 }
922 
923 namespace {
924 
925 /// Remove duplicate values in the list of memrefs to be deallocated. We need to
926 /// make sure the corresponding condition value is updated accordingly since
927 /// their two conditions might not cover the same set of cases. In that case, we
928 /// have to combine them (by computing the disjunction of them).
929 /// Example:
930 /// ```mlir
931 /// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
932 /// ```
933 /// is canonicalized to
934 /// ```mlir
935 /// %0 = arith.ori %arg1, %arg2 : i1
936 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
937 /// ```
938 struct DeallocRemoveDuplicateDeallocMemrefs
939  : public OpRewritePattern<DeallocOp> {
941 
942  LogicalResult matchAndRewrite(DeallocOp deallocOp,
943  PatternRewriter &rewriter) const override {
944  // Unique memrefs to be deallocated.
945  DenseMap<Value, unsigned> memrefToCondition;
946  SmallVector<Value> newMemrefs, newConditions;
947  for (auto [i, memref, cond] :
948  llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
949  if (memrefToCondition.count(memref)) {
950  // If the dealloc conditions don't match, we need to make sure that the
951  // dealloc happens on the union of cases.
952  Value &newCond = newConditions[memrefToCondition[memref]];
953  if (newCond != cond)
954  newCond =
955  rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
956  } else {
957  memrefToCondition.insert({memref, newConditions.size()});
958  newMemrefs.push_back(memref);
959  newConditions.push_back(cond);
960  }
961  }
962 
963  // Return failure if we don't change anything such that we don't run into an
964  // infinite loop of pattern applications.
965  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
966  rewriter);
967  }
968 };
969 
970 /// Remove duplicate values in the list of retained memrefs. We need to make
971 /// sure the corresponding result condition value is replaced properly.
972 /// Example:
973 /// ```mlir
974 /// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
975 /// ```
976 /// is canonicalized to
977 /// ```mlir
978 /// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
979 /// ```
980 struct DeallocRemoveDuplicateRetainedMemrefs
981  : public OpRewritePattern<DeallocOp> {
983 
984  LogicalResult matchAndRewrite(DeallocOp deallocOp,
985  PatternRewriter &rewriter) const override {
986  // Unique retained values
988  SmallVector<Value> newRetained;
989  SmallVector<unsigned> resultReplacementIdx;
990  unsigned i = 0;
991  for (auto retained : deallocOp.getRetained()) {
992  if (seen.count(retained)) {
993  resultReplacementIdx.push_back(seen[retained]);
994  continue;
995  }
996 
997  seen[retained] = i;
998  newRetained.push_back(retained);
999  resultReplacementIdx.push_back(i++);
1000  }
1001 
1002  // Return failure if we don't change anything such that we don't run into an
1003  // infinite loop of pattern applications.
1004  if (newRetained.size() == deallocOp.getRetained().size())
1005  return failure();
1006 
1007  // We need to create a new op because the number of results is always the
1008  // same as the number of condition operands.
1009  auto newDeallocOp =
1010  rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
1011  deallocOp.getConditions(), newRetained);
1012  SmallVector<Value> replacements(
1013  llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
1014  return newDeallocOp.getUpdatedConditions()[idx];
1015  }));
1016  rewriter.replaceOp(deallocOp, replacements);
1017  return success();
1018  }
1019 };
1020 
1021 /// Erase deallocation operations where the variadic list of memrefs to
1022 /// deallocate is empty. Example:
1023 /// ```mlir
1024 /// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1025 /// ```
1026 struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1028 
1029  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1030  PatternRewriter &rewriter) const override {
1031  if (deallocOp.getMemrefs().empty()) {
1032  Value constFalse = rewriter.create<arith::ConstantOp>(
1033  deallocOp.getLoc(), rewriter.getBoolAttr(false));
1034  rewriter.replaceOp(
1035  deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1036  constFalse));
1037  return success();
1038  }
1039  return failure();
1040  }
1041 };
1042 
1043 /// Removes memrefs from the deallocation list if their associated condition is
1044 /// always 'false'.
1045 ///
1046 /// Example:
1047 /// ```
1048 /// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1049 /// if (%arg2, %false)
1050 /// ```
1051 /// becomes
1052 /// ```
1053 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1054 /// ```
1055 struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1057 
1058  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1059  PatternRewriter &rewriter) const override {
1060  SmallVector<Value> newMemrefs, newConditions;
1061  for (auto [memref, cond] :
1062  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1063  if (!matchPattern(cond, m_Zero())) {
1064  newMemrefs.push_back(memref);
1065  newConditions.push_back(cond);
1066  }
1067  }
1068 
1069  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1070  rewriter);
1071  }
1072 };
1073 
1074 /// The `memref.extract_strided_metadata` is often inserted to get the base
1075 /// memref if the operand is not already guaranteed to be the result of a memref
1076 /// allocation operation. This canonicalization pattern removes this extraction
1077 /// operation if the operand is now produced by an allocation operation (e.g.,
1078 /// due to other canonicalizations simplifying the IR).
1079 ///
1080 /// Example:
1081 /// ```mlir
1082 /// %alloc = memref.alloc() : memref<2xi32>
1083 /// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1084 /// %alloc : memref<2xi32> -> memref<i32>, index, index, index
1085 /// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1086 /// ```
1087 /// is canonicalized to
1088 /// ```mlir
1089 /// %alloc = memref.alloc() : memref<2xi32>
1090 /// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1091 /// ```
1092 struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1094 
1095  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1096  PatternRewriter &rewriter) const override {
1097  SmallVector<Value> newMemrefs(
1098  llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1099  auto extractStridedOp =
1100  memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1101  if (!extractStridedOp)
1102  return memref;
1103  Value allocMemref = extractStridedOp.getOperand();
1104  auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1105  if (!allocOp)
1106  return memref;
1107  if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1108  return allocMemref;
1109  return memref;
1110  }));
1111 
1112  return updateDeallocIfChanged(deallocOp, newMemrefs,
1113  deallocOp.getConditions(), rewriter);
1114  }
1115 };
1116 
1117 /// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1118 /// other user of the allocated value and the allocating operation can be safely
1119 /// removed. If the same value is present multiple times, this pattern relies on
1120 /// other canonicalization patterns to remove the duplicate first.
1121 ///
1122 /// Example:
1123 /// ```mlir
1124 /// %alloc = memref.alloc() : memref<2xi32>
1125 /// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
1126 /// ```
1127 /// is canonicalized to
1128 /// ```mlir
1129 /// bufferization.dealloc (%arg0 : ...) if (%true)
1130 /// ```
1131 struct RemoveAllocDeallocPairWhenNoOtherUsers
1132  : public OpRewritePattern<DeallocOp> {
1134 
1135  LogicalResult matchAndRewrite(DeallocOp deallocOp,
1136  PatternRewriter &rewriter) const override {
1137  SmallVector<Value> newMemrefs, newConditions;
1138  SmallVector<Operation *> toDelete;
1139  for (auto [memref, cond] :
1140  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1141  if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1142  // Check that it is indeed an allocate effect, that the op has no other
1143  // side effects (which would not allow us to remove the op), and that
1144  // there are no other users.
1145  if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1146  hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
1147  memref.hasOneUse()) {
1148  toDelete.push_back(allocOp);
1149  continue;
1150  }
1151  }
1152 
1153  newMemrefs.push_back(memref);
1154  newConditions.push_back(cond);
1155  }
1156 
1157  if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1158  rewriter)))
1159  return failure();
1160 
1161  for (Operation *op : toDelete)
1162  rewriter.eraseOp(op);
1163 
1164  return success();
1165  }
1166 };
1167 
1168 } // anonymous namespace
1169 
1170 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1171  MLIRContext *context) {
1173 }
1174 
1176  RewritePatternSet &patterns, MLIRContext *context) {
1177  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1178  DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1179  EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1180  RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // TableGen'd op method definitions
1185 //===----------------------------------------------------------------------===//
1186 
1187 #define GET_OP_CLASSES
1188 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base class for generic analysis states.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseCustomTypeWithFallback(Type &result, function_ref< ParseResult(Type &result)> parseType)=0
Parse a custom type with the provided callback, unless the next token is #, in which case the generic...
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
void printStrippedAttrOrType(AttrOrType attrOrType)
Print the provided attribute in the context of an operation custom printer/parser: this will invoke d...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
IndexType getIndexType()
Definition: Builders.cpp:95
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref, const BufferizationOptions &options)
Try to fold to_memref(to_tensor(x)).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the giv...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
The following effect indicates that the operation allocates from some resource.
The following effect indicates that the operation frees some resource that has been allocated.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Options for BufferizableOpInterface-based bufferization.