//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type and rank must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
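  // For example, casting memref<?xf32, strided<[?], offset: ?>> to the
  // identity-layout type memref<?xf32> goes from a dynamic to a static offset
  // and stride, which a memref.cast cannot guarantee at runtime; such cases
  // take the alloc-and-copy path below.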
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, a cast is valid, but may still
  // fail at runtime. To ensure that we only generate casts that always succeed
  // at runtime, we check a few extra conditions in
  // `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }

  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}

/// Try to fold to_buffer(to_tensor(x)). If x's type and the result type of the
/// to_buffer op are different, a memref.cast is needed.
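/// Example (schematic IR, types elided): given
///   %t = bufferization.to_tensor %m
///   %b = bufferization.to_buffer %t
/// %b folds to %m directly when the types match, to a memref.cast of %m when
/// the cast is guaranteed to succeed, and to a fresh alloc + copy otherwise.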
LogicalResult mlir::bufferization::foldToBufferToTensorPair(
    RewriterBase &rewriter, ToBufferOp toBuffer,
    const BufferizationOptions &options) {
  auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
  if (!bufferToTensor)
    return failure();

  Type srcType = bufferToTensor.getMemref().getType();
  Type destType = toBuffer.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toBuffer, bufferToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, bufferToTensor.getMemref(), rankedDestType, options);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toBuffer, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
                                              bufferToTensor.getMemref());
  return success();
}

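/// Populate `dynamicDims` with tensor::DimOp / memref::DimOp results for all
/// dynamic dimensions of the given shaped value.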
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer =
        getBuffer(rewriter, getCopy(), options, state);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options, state);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}

bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType = bufferization::getBufferType(getCopy(), options,
                                                       state, invocationStack);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was provided via a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5 : index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// is rewritten to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

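/// Fold `tensor.dim` of a dynamic dimension of an `alloc_tensor` to the
/// corresponding dynamic size operand. For example:
///
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
///
/// folds %d to %sz.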
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
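/// For example, when the clone and its dealloc are in the same block:
///
///   %1 = bufferization.clone %0 : memref<2xf32> to memref<2xf32>
///   ... uses of %1 ...
///   memref.dealloc %1 : memref<2xf32>
///
/// all uses of %1 are replaced with %0 (through a memref.cast if the types
/// differ) and the now-redundant dealloc is erased.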
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aim to find the dealloc op for the canonical source, which otherwise
    // could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

467 
468  std::optional<Operation *> maybeCloneDeallocOp =
469  memref::findDealloc(cloneOp.getOutput());
470  // Skip if either of them has > 1 deallocate operations.
471  if (!maybeCloneDeallocOp.has_value())
472  return failure();
473  std::optional<Operation *> maybeSourceDeallocOp =
474  memref::findDealloc(canonicalSource);
475  if (!maybeSourceDeallocOp.has_value())
476  return failure();
477  Operation *cloneDeallocOp = *maybeCloneDeallocOp;
478  Operation *sourceDeallocOp = *maybeSourceDeallocOp;
479 
480  // If both are deallocated in the same block, their in-block lifetimes
481  // might not fully overlap, so we cannot decide which one to drop.
482  if (cloneDeallocOp && sourceDeallocOp &&
483  cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
484  return failure();
485 
486  Block *currentBlock = cloneOp->getBlock();
487  Operation *redundantDealloc = nullptr;
488  if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
489  redundantDealloc = cloneDeallocOp;
490  } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
491  redundantDealloc = sourceDeallocOp;
492  }
493 
494  if (!redundantDealloc)
495  return failure();
496 
497  // Safety check that there are no other deallocations inbetween
498  // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
499  // of source before the uses of the clone. With alias information, we could
500  // restrict this to only fail of the dealloc's operand is an alias
501  // of the source.
502  for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
503  pos = pos->getNextNode()) {
504  // Bail if we run out of operations while looking for a deallocation op.
505  if (!pos)
506  return failure();
507  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
508  if (!effectInterface)
509  continue;
510  if (effectInterface.hasEffect<MemoryEffects::Free>())
511  return failure();
512  }
513 
514  if (source.getType() != cloneOp.getType())
515  source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
516  cloneOp.getType(), source);
517  rewriter.replaceOp(cloneOp, source);
518  rewriter.eraseOp(redundantDealloc);
519  return success();
520  }
521 };
522 
} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options,
                                         BufferizationState &state) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize in-place
  // (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options,
                                      BufferizationState &state) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, getDest(), options, state);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension sizes. Assume that they match at
        // runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toBuffer = getMemref().getDefiningOp<ToBufferOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
        toBuffer->getNextNode() == this->getOperation())
      return toBuffer.getTensor();
  return {};
}

namespace {
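
/// Replace `tensor.dim` of a `to_tensor` result with `memref.dim` on the
/// underlying memref. For example:
///
///   %t = bufferization.to_tensor %m
///   %d = tensor.dim %t, %c0
///
/// becomes %d = memref.dim %m, %c0.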
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToBufferOp
//===----------------------------------------------------------------------===//

OpFoldResult ToBufferOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
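/// For example (schematic IR, buffer types abbreviated):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_buffer %0 ... memref<?xf32>
///
/// becomes
///
///   %b = bufferization.to_buffer %t ... memref<4xf32>
///   %m = memref.cast %b : memref<4xf32> to memref<?xf32>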
struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
/// cast if necessary.
struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToBufferToTensorPair(rewriter, toBuffer, options);
  }
};

/// Fold a load on a to_buffer operation into a tensor.extract on the
/// corresponding tensor.
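/// For example:
///
///   %m = bufferization.to_buffer %t
///   %v = memref.load %m[%i]
///
/// becomes %v = tensor.extract %t[%i].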
struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
    if (!toBuffer)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_buffer into the dim of the tensor.
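/// For example:
///
///   %m = bufferization.to_buffer %t
///   %d = memref.dim %m, %c0
///
/// becomes %d = tensor.dim %t, %c0.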
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(context);
}

LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToBufferToTensorPair(rewriter, *this, options);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

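/// Replace the memref and condition operands of `deallocOp` in place, or
/// return failure if nothing would change (so that callers do not loop
/// indefinitely re-applying the same pattern).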
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. We need
/// to make sure the corresponding condition value is updated accordingly since
/// the two conditions might not cover the same set of cases. In that case, we
/// have to combine them (by computing the disjunction of them).
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything so that we don't run into an
    // infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything so that we don't run into an
    // infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of condition operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
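/// The op is erased and each of its results (the updated conditions) is
/// replaced by a constant `false`, since nothing was deallocated.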
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition is
/// always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                       if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// The `memref.extract_strided_metadata` op is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"