BufferizationOps.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>

using namespace mlir;
using namespace mlir::bufferization;
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }

  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}
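
// Illustrative behavior (a hedged sketch; the IR below is a made-up example):
// when the destination type does not introduce a static offset/stride where
// the source had a dynamic one, a plain cast is emitted, e.g.
//
//   %0 = memref.cast %src : memref<4xf32> to memref<?xf32>
//
// Otherwise the cast cannot be proven to succeed at runtime, so the helper
// allocates a buffer of the destination type (with memref.dim ops supplying
// the dynamic sizes) and copies the data over via options.createMemCpy.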

/// Try to fold to_buffer(to_tensor(x)). If x's type and the result type of the
/// to_buffer op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToBufferToTensorPair(
    RewriterBase &rewriter, ToBufferOp toBuffer,
    const BufferizationOptions &options) {
  auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
  if (!bufferToTensor)
    return failure();

  Type srcType = bufferToTensor.getBuffer().getType();
  Type destType = toBuffer.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toBuffer, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
                                              bufferToTensor.getBuffer());
  return success();
}
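
// For illustration (an added, hedged note): given
//
//   %t = bufferization.to_tensor %m ...
//   %b = bufferization.to_buffer %t ...
//
// the to_buffer folds to %m directly when the memref types match; for ranked
// types a memref.cast (or, if needed, an alloc + copy) is inserted via
// castOrReallocMemRefValue, and for unranked destinations a plain memref.cast
// is emitted.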

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}
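
// Added example: for a value of type memref<?x8x?xf32>, this appends two
// memref.dim results (for dimensions 0 and 2) to `dynamicDims`; for a ranked
// tensor it would create tensor.dim ops instead.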

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer =
        getBuffer(rewriter, getCopy(), options, state);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options, state);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}
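
// Sketch of the resulting IR (hedged; the exact alloc/copy ops depend on
// options.createAlloc/createMemCpy, and the names below are made up):
//
//   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>
//
// becomes, after bufferization,
//
//   %d0 = memref.dim %t_buffer, %c0 : memref<?xf32>
//   %alloc = memref.alloc(%d0) : memref<?xf32>
//   memref.copy %t_buffer, %alloc : memref<?xf32> to memref<?xf32>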

bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType = asMemRefType(bufferization::getBufferType(
        getCopy(), options, state, invocationStack));
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation but whose size was provided through a
/// `constant` op. For example:
///
/// %c5 = arith.constant 5: index
/// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
/// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

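/// Fold `tensor.dim` of a dynamic dimension of an `alloc_tensor` result to the
/// value returned by AllocTensorOp::getDynamicSize, i.e. the matching dynamic
/// size operand (or a tensor.dim of the `copy` value). Added, hedged example
/// with made-up names:
///
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
///
/// folds %d to %sz.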
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
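/// Added, hedged illustration (made-up IR): given
///
///   %clone = bufferization.clone %buf : memref<?xf32> to memref<?xf32>
///   ...                      // no other deallocation in between
///   memref.dealloc %clone : memref<?xf32>
///
/// the clone is replaced by %buf (through a memref.cast if the types differ)
/// and the now-redundant dealloc is erased.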
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aim to find the dealloc op for the canonical source, which otherwise
    // could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for a deallocation op.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options,
                                         BufferizationState &state) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize in-place
  // (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options,
                                      BufferizationState &state) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, getDest(), options, state);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}
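
// Hedged example (made-up IR; see the op documentation for the exact syntax):
//
//   %r = bufferization.materialize_in_destination %src in %dest
//       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//
// bufferizes to a memcpy from the source buffer into the destination buffer,
// and %r is replaced by that destination buffer.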

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(
      loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(),
      /*restrict=*/true, getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension size. Assume that they match at
        // runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
        toBuffer->getNextNode() == this->getOperation())
      return toBuffer.getTensor();
  return {};
}
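
// In short (an added note): to_tensor(to_buffer(%t)) folds back to %t, but
// only when the to_buffer op immediately precedes this to_tensor in the same
// block, as a conservative stand-in for alias analysis (any interleaved op
// could write to the buffer).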

namespace {
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToBufferOp
//===----------------------------------------------------------------------===//

OpFoldResult ToBufferOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getBuffer().getType() == getType())
      return memrefToTensor.getBuffer();
  return {};
}

namespace {

/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
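/// Added, hedged illustration (made-up IR; memref types abbreviated): instead
/// of
///
///   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_buffer %c : ... to memref<?xf32>
///
/// the buffer is taken from the original tensor and cast afterwards:
///
///   %m0 = bufferization.to_buffer %t : ... to memref<4xf32>
///   %m  = memref.cast %m0 : memref<4xf32> to memref<?xf32>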
struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
/// cast if necessary.
struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToBufferToTensorPair(rewriter, toBuffer, options);
  }
};

/// Fold a load on a to_buffer operation into a tensor.extract on the
/// corresponding tensor.
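/// Added, hedged illustration (made-up IR; types abbreviated):
///
///   %m = bufferization.to_buffer %t : ... to memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>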
struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
    if (!toBuffer)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_buffer into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(context);
}

LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToBufferToTensorPair(rewriter, *this, options);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. If a
/// memref occurs more than once, the corresponding condition value has to be
/// updated accordingly, since the individual conditions might not cover the
/// same set of cases. In that case, they have to be combined (by computing
/// their disjunction).
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything, so that we don't run into an
    // infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything, so that we don't run into an
    // infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of condition operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
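/// Each result (updated condition) is then replaced by a constant false,
/// since nothing is deallocated (a brief added summary of the rewrite below).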
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition is
/// always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                       if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// The `memref.extract_strided_metadata` is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}
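
// Hedged usage sketch (not part of the upstream file): a pass that wants these
// dealloc simplifications outside of the regular canonicalizer could run them
// through the greedy pattern driver, e.g.
//
//   RewritePatternSet patterns(op->getContext());
//   bufferization::populateDeallocOpCanonicalizationPatterns(
//       patterns, op->getContext());
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
// where applyPatternsGreedily (named applyPatternsAndFoldGreedily in older
// releases) comes from mlir/Transforms/GreedyPatternRewriteDriver.h.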

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"