BufferizationOps.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

/// Try to cast the given ranked MemRef-typed value to the given ranked MemRef
/// type. If a safe cast is not possible, reallocate and copy instead.
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, a cast is valid but may still
  // fail at runtime. To ensure that we only generate casts that always succeed
  // at runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
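  // For example (illustrative): casting a memref whose layout has a dynamic
  // offset to one with a static offset is accepted by `areCastCompatible`, but
  // whether it succeeds depends on the runtime offset value, so
  // `isGuaranteedCastCompatible` rejects it and the code below falls back to
  // an allocation plus copy.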
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
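///
/// Illustrative example (hypothetical IR, names chosen for exposition): with
/// matching types,
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %r = bufferization.to_memref %t : memref<4xf32>
/// all uses of %r are simply replaced with %m; if the to_memref result type
/// differs but is cast compatible, a memref.cast of %m is inserted instead.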
LogicalResult
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

/// Populate `dynamicDims` with tensor::DimOp / memref::DimOp results for all
/// dynamic dimensions of the given shaped value.
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute memory space of this allocation.
  unsigned memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    memorySpace =
        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
  } else if (options.defaultMemorySpace.has_value()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation.
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  bool dealloc =
      shouldDeallocateOpResult(getResult().cast<OpResult>(), options);

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError(
            "sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy,
        /*memory_space=*/IntegerAttr());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was defined by a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
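///
/// Illustrative example (hypothetical IR): if the clone is deallocated in the
/// same block and no other deallocation can happen in between,
///   %clone = bufferization.clone %source : memref<?xf32> to memref<?xf32>
///   ...
///   memref.dealloc %clone : memref<?xf32>
/// is rewritten into
///   %clone = memref.cast %source : memref<?xf32> to memref<?xf32>
/// and the now-redundant dealloc is erased.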
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  if (failed(options.createDealloc(rewriter, getLoc(), *buffer)))
    return failure();
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {
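/// Fold dim of a to_tensor into the dim of the source memref. Illustrative
/// example (hypothetical IR):
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// becomes
///   %d = memref.dim %m, %c0 : memref<?xf32>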
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
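///
/// Illustrative example (hypothetical IR):
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_memref %0 : memref<?xf32>
/// becomes
///   %m0 = bufferization.to_memref %t : memref<4xf32>
///   %m  = memref.cast %m0 : memref<4xf32> to memref<?xf32>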
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    return foldToMemrefToTensorPair(rewriter, toMemref);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
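///
/// Illustrative example (hypothetical IR):
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
/// becomes
///   %v = tensor.extract %t[%i] : tensor<?xf32>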
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
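///
/// Illustrative example (hypothetical IR):
///   %m = bufferization.to_memref %t : memref<?x?xf32>
///   %d = memref.dim %m, %c1 : memref<?x?xf32>
/// becomes
///   %d = tensor.dim %t, %c1 : tensor<?x?xf32>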
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"