BufferizationOps.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

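/// Try to cast the given ranked memref value to the given ranked memref
/// destination type. If a `memref.cast` is not guaranteed to succeed at
/// runtime, allocate a buffer of the destination type and copy the data
/// instead.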
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }

  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
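/// Schematically (illustrative, not part of the original comment):
/// to_memref(to_tensor(%m)) folds to %m, possibly wrapped in a memref.cast
/// (or reallocated and copied) when the two memref types differ.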
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref,
    const BufferizationOptions &options) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType, options);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}

bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

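/// Compute the buffer type of this allocation: a memref with a static identity
/// layout whose memory space is taken from the `memory_space` attribute, the
/// `copy` operand, or the default memory space function, in that order.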
FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType =
        bufferization::getBufferType(getCopy(), options, invocationStack);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

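/// Fold `tensor.dim` of a dynamic dimension of a `bufferization.alloc_tensor`
/// to the corresponding dynamic size value.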
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

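// The custom assembly handled by the parser/printer below has the form
//   bufferization.alloc_tensor(%dynamic_sizes) copy(%src) size_hint=%hint
//       {attrs} : tensor<...>
// where the `copy` and `size_hint` clauses are optional.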
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
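/// For example (schematic, conditions simplified): a `bufferization.clone
/// %src` whose result (or whose source) is deallocated in the clone's own
/// block, with no other deallocation in between, is replaced by `%src`
/// (through a `memref.cast` if the types differ), and the now redundant
/// `memref.dealloc` is erased.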
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aims to find the dealloc op for the canonical source
    // which otherwise could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for a deallocation op.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize in-place
  // (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

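/// Bufferization of this op is a memcpy from the source buffer into the
/// destination buffer. If the destination is already a memref, it is used
/// directly; otherwise its buffer is looked up first.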
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension size. Assume that they match at
        // runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {
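/// Fold `tensor.dim` of a `to_tensor` op to a `memref.dim` on the underlying
/// memref.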
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToMemrefToTensorPair(rewriter, toMemref, options);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this, options);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

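/// Replace the memref and condition operands of `deallocOp` in place if they
/// differ from the given lists; return failure otherwise so that callers do
/// not loop on an unchanged op.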
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. We need
/// to make sure the corresponding condition value is updated accordingly since
/// the two conditions might not cover the same set of cases. In that case, we
/// have to combine them (by computing the disjunction of them).
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything such that we don't run into
    // an infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything such that we don't run into
    // an infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of condition operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
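/// The dealloc op is erased and each updated-condition result is replaced by a
/// constant `false`.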
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition is
/// always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                           if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// The `memref.extract_strided_metadata` is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///     %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"