//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;
static bool isRepetitiveRegion(Region *region,
                               const BufferizationOptions &options) {
  Operation *op = region->getParentOp();
  if (auto bufferizableOp = options.dynCastBufferizableOp(op))
    if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
      return true;
  return false;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Operation *op, const BufferizationOptions &options) {
  if (!op->getBlock())
    return nullptr;
  if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;
  return enclosingRepetitiveRegionCache[op] =
             getEnclosingRepetitiveRegion(op->getBlock(), options);
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Value value, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = value.getParentRegion();
  // Collect all visited regions; which repetitive region they map to is only
  // known once the traversal ends.
  SmallVector<Region *> visitedRegions;
  while (region) {
    visitedRegions.push_back(region);
    if (isRepetitiveRegion(region, options))
      break;
    region = region->getParentRegion();
  }
  enclosingRepetitiveRegionCache[value] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Block *block, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = block->getParent();
  Operation *op = nullptr;
  // Collect all visited regions; which repetitive region they map to is only
  // known once the traversal ends.
  SmallVector<Region *> visitedRegions;
  do {
    op = region->getParentOp();
    if (isRepetitiveRegion(region, options))
      break;
  } while ((region = op->getParentRegion()));

  enclosingRepetitiveRegionCache[block] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

bool AnalysisState::insideMutuallyExclusiveRegions(Operation *op0,
                                                   Operation *op1) {
  auto key = std::make_pair(op0, op1);
  if (auto iter = insideMutuallyExclusiveRegionsCache.find(key);
      iter != insideMutuallyExclusiveRegionsCache.end())
    return iter->second;
  bool result = ::mlir::insideMutuallyExclusiveRegions(op0, op1);
  // Populate results for both orderings of the ops.
  insideMutuallyExclusiveRegionsCache[key] = result;
  insideMutuallyExclusiveRegionsCache[std::make_pair(op1, op0)] = result;
  return result;
}

void AnalysisState::resetCache() {
  enclosingRepetitiveRegionCache.clear();
  insideMutuallyExclusiveRegionsCache.clear();
}

Region *bufferization::getNextEnclosingRepetitiveRegion(
    Region *region, const BufferizationOptions &options) {
  assert(isRepetitiveRegion(region, options) && "expected repetitive region");
  while ((region = region->getParentRegion())) {
    if (isRepetitiveRegion(region, options))
      break;
  }
  return region;
}

Region *bufferization::getParallelRegion(Region *region,
                                         const BufferizationOptions &options) {
  while (region) {
    auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp());
    if (bufferizableOp &&
        bufferizableOp.isParallelRegion(region->getRegionNumber())) {
      assert(isRepetitiveRegion(region, options) &&
             "expected that all parallel regions are also repetitive regions");
      return region;
    }
    region = region->getParentRegion();
  }
  return nullptr;
}

Operation *bufferization::getOwnerOfValue(Value value) {
  if (auto opResult = llvm::dyn_cast<OpResult>(value))
    return opResult.getDefiningOp();
  return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
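///
/// A sketch of the resulting IR (illustrative only; SSA names are made up,
/// and the exact assembly depends on the MLIR version):
///   // with `copy`:
///   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>
///   // without `copy` (undefined contents; dynamic extents are reified or
///   // queried via tensor.dim):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///   %1 = bufferization.alloc_tensor(%d) : tensor<?xf32>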
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
    tensor = shapedValue;
  } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
             llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
    return getOwnerOfValue(shapedValue)
        ->emitError("copying of unranked tensors is not implemented");
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
        llvm::isa<OpResult>(shapedValue)) {
      ReifiedRankedShapedTypeDims resultDims;
      if (succeeded(
              reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
        reifiedShapes = true;
        auto &shape =
            resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
        for (const auto &dim : enumerate(tensorType.getShape()))
          if (ShapedType::isDynamic(dim.value()))
            dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
  if (!memorySpace)
    memorySpace = options.defaultMemorySpaceFn(tensorType);
  if (memorySpace.has_value())
    allocTensorOp.setMemorySpaceAttr(memorySpace.value());
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<Value> outOfPlaceValues;
  DenseSet<Value> copiedOpValues;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!llvm::isa<TensorType>(operandType))
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (llvm::isa<UnrankedTensorType>(operandType))
      return op->emitError("copying of unranked tensors is not implemented");

    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    if (aliasingValues.getNumAliases() == 1 &&
        isa<OpResult>(aliasingValues.getAliases()[0].value) &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                .getNumAliases() == 1 &&
        !isa<UnrankedTensorType>(
            aliasingValues.getAliases()[0].value.getType())) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source). Do not apply
      // this optimization if the OpResult is an unranked tensor (because those
      // cannot be copied at the moment).
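      //
      // Example (illustrative): for
      //   %0 = tensor.extract_slice %t[0] [4] [1]
      //       : tensor<16xf32> to tensor<4xf32>
      // copying the 4-element result %0 is cheaper than copying the
      // 16-element source %t.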
      Value value = aliasingValues.getAliases()[0].value;
      outOfPlaceValues.push_back(value);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpValues.insert(value);
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of Values.
  rewriter.setInsertionPointAfter(op);
  for (Value value : outOfPlaceValues) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), value, state.getOptions(),
        copiedOpValues.count(value));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(
        llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() == copy->getDefiningOp())
        continue;
      // tensor.dim ops may have been created to be used as alloc_tensor op
      // dynamic extents. Do not update these either.
      if (isa<tensor::DimOp>(use->getOwner()))
        continue;
      rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}
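
// Example (illustrative): an OpFilter with an ALLOW entry for the tensor
// dialect and a DENY entry for tensor.empty allows all tensor dialect ops
// except tensor.empty; because an ALLOW rule exists, ops from all other
// dialects are rejected.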

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

namespace {

/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
                                func::FuncOp funcOp,
                                const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(
      llvm::cast<TensorType>(value.getType()), memorySpace);
}

} // namespace

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (!isOpAllowed(op))
    return nullptr;
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  return dynCastBufferizableOp(getOwnerOfValue(value));
}

void BufferizationOptions::setFunctionBoundaryTypeConversion(
    LayoutMapOption layoutMapOption) {
  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
                                   func::FuncOp funcOp,
                                   const BufferizationOptions &options) {
    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
                                                                  memorySpace);
    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                              memorySpace);
  };
  inferFunctionResultLayout =
      layoutMapOption == LayoutMapOption::InferLayoutMap;
}
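
// Example (illustrative): with LayoutMapOption::IdentityLayoutMap, a function
// argument of type tensor<?xf32> bufferizes to memref<?xf32>; otherwise it
// bufferizes to memref<?xf32, strided<[?], offset: ?>>.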

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `value` if the op is bufferized
/// in place. Return all tensor OpOperand* if the op is not bufferizable.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
  if (Operation *op = getOwnerOfValue(value))
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperands(value, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingOpOperands(value);
}

/// Determine which Values will alias with `opOperand` if the op is bufferized
/// in place. Return all tensor Values if the op is not bufferizable.
AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingValues(opOperand, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingValues(opOperand);
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
  auto opResult = llvm::dyn_cast<OpResult>(value);
  if (!opResult)
    return true;
  auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
  if (!bufferizableOp)
    return true;
  return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  DenseSet<OpOperand *> visited;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    if (!visited.insert(uMaybeReading).second)
      continue;

    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (AliasingValue alias : getAliasingValues(*uMaybeReading))
        for (OpOperand &use : alias.value.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `opOperand`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. Uses of such matching Values are not
// traversed any further; the aliasing OpOperands visited along the way are
// recorded in `visitedOpOperands`.
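//
// Example (illustrative ops): given
//   %0 = "writing_op"() : tensor<4xf32>
//   %1 = tensor.extract_slice %0[0] [2] [1] : tensor<4xf32> to tensor<2xf32>
//   "consumer_op"(%1)
// and a `condition` that matches values bufferizing to a memory write,
// starting from the use of %1 in "consumer_op" the traversal follows the
// aliasing OpOperand of the extract_slice and returns {%0}.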
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
    TraversalConfig config,
    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
  llvm::DenseSet<Value> visited;
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(opOperand->get());

  if (visitedOpOperands)
    visitedOpOperands->insert(opOperand);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();

    if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
      // Stop traversal if value was already visited.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }
    visited.insert(value);

    if (condition(value)) {
      result.insert(value);
      continue;
    }

    if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) {
      // Stop iterating if `followUnknownOps` is unset and the op is either
      // not bufferizable or excluded in the OpFilter.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    AliasingOpOperandList aliases = getAliasingOpOperands(value);
    if (aliases.getNumAliases() == 0) {
      // The traversal ends naturally if there are no more OpOperands that
      // could be followed.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    for (AliasingOpOperand a : aliases) {
      if (config.followEquivalentOnly &&
          a.relation != BufferRelation::Equivalent) {
        // Stop iterating if `followEquivalentOnly` is set but the alias is not
        // equivalent.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
        // Stop iterating if `followInPlaceOnly` is set but the alias is
        // out-of-place.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followSameTypeOrCastsOnly &&
          a.opOperand->get().getType() != value.getType() &&
          !value.getDefiningOp<CastOpInterface>()) {
        // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
        // has a different type and the op is not a cast.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      workingSet.insert(a.opOperand->get());
      if (visitedOpOperands)
        visitedOpOperands->insert(a.opOperand);
    }
  }

  return result;
}

// Find the values that define the contents of the given operand's value.
llvm::SetVector<Value>
AnalysisState::findDefinitions(OpOperand *opOperand) const {
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  return findValueInReverseUseDefChain(
      opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
      config);
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : AnalysisState(options, TypeID::get<AnalysisState>()) {}

AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
    : options(options), type(type) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  AliasingValueList aliases = getAliasingValues(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliases,
                    [&](AliasingValue a) { return isValueRead(a.value); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}
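
// Example (illustrative): for %0 = linalg.fill ins(%cst) outs(%t), the `outs`
// operand %t is overwritten without being read, so an allocation replacing it
// does not need a copy of %t.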

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may be
  // aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is "false".
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
  assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

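/// Lookup the buffer for the given tensor value. A sketch (illustrative; the
/// exact assembly depends on the MLIR version): if the value is defined by
///   %t = bufferization.to_tensor %m : memref<?xf32>
/// the lookup folds to %m; otherwise a bufferization.to_memref op is inserted
/// right after the definition of the value.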
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  SmallVector<Value> invocationStack;
  return getBufferType(value, options, invocationStack);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) &&
         "unexpected non-tensor type");
  invocationStack.push_back(value);
  auto popFromStack =
      llvm::make_scope_exit([&]() { invocationStack.pop_back(); });

  // Try querying BufferizableOpInterface.
  Operation *op = getOwnerOfValue(value);
  auto bufferizableOp = options.dynCastBufferizableOp(op);
  if (bufferizableOp)
    return bufferizableOp.getBufferType(value, options, invocationStack);

  // Op is not bufferizable.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::hasTensorSemantics(Operation *op) {
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
    return bufferizableOp.hasTensorSemantics();
  return detail::defaultHasTensorSemantics(op);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (llvm::isa<TensorType>(opResult.getType())) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((llvm::isa<MemRefType>(replacement.getType()) ||
              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), opResult.getType(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}
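
// Example (illustrative): replacing %0 = "tensor_producer"() : tensor<4xf32>
// with a memref value %m first materializes %r = bufferization.to_tensor %m,
// then replaces all uses of %0 with %r; the to_tensor op is expected to fold
// away as the remaining users are bufferized.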

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  auto tensorType = llvm::cast<TensorType>(value.getType());

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  int64_t dynamicOffset = ShapedType::kDynamic;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamic);
  auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
                                              dynamicOffset, dynamicStrides);
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}
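
// Example (illustrative): with the fully dynamic layout above, tensor<4x?xf32>
// maps to memref<4x?xf32, strided<[?, ?], offset: ?>>.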

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}
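
// Example (illustrative): with the static identity layout above,
// tensor<4x?xf32> maps to a plain memref<4x?xf32>.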

//===----------------------------------------------------------------------===//
// Default implementations of interface methods
//===----------------------------------------------------------------------===//

bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
    OpResult opResult, const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
  AliasingOpOperandList opOperands =
      bufferizableOp.getAliasingOpOperands(opResult, state);

  // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
  // memory writes.
  if (opOperands.getAliases().empty())
    return true;

  // Case 2: If an aliasing OpOperand bufferizes to a memory write, the OpResult
  // may bufferize to a memory write.
  if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) {
        return state.bufferizesToMemoryWrite(*alias.opOperand);
      }))
    return true;

  // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
  // write. (Or: The reverse SSA use-def chain ends inside the region.) In that
  // case, the OpResult bufferizes to a memory write. E.g.:
  //
  // %0 = "some_writing_op" : tensor<?xf32>
  // %r = scf.if ... -> tensor<?xf32> {
  //   scf.yield %0 : tensor<?xf32>
  // } else {
  //   %1 = "another_writing_op"(%0) : tensor<?xf32>
  //   scf.yield %1 : tensor<?xf32>
  // }
  // "some_reading_op"(%r)
  //
  // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
  // bufferizes to a memory write and the defining op is inside the scf.if.
  //
  // Note: This treatment of surrounding ops is useful for ops that have a
  // region but no OpOperand such as scf.if or scf.execute_region. It simplifies
  // the analysis considerably.
  //
  // "another_writing_op" in the above example should be able to bufferize
  // inplace in the absence of another read of %0. However, if the scf.if op
  // would not be considered a "write", the analysis would detect the
  // following conflict:
  //
  // * read = some_reading_op
  // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
  // * conflictingWrite = %1
  //
  auto isMemoryWriteInsideOp = [&](Value v) {
    Operation *op = getOwnerOfValue(v);
    if (!opResult.getDefiningOp()->isAncestor(op))
      return false;
    return state.bufferizesToMemoryWrite(v);
  };
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  for (AliasingOpOperand alias : opOperands) {
    if (!state
             .findValueInReverseUseDefChain(alias.opOperand,
                                            isMemoryWriteInsideOp, config)
             .empty())
      return true;
  }
  return false;
}

// Compute the AliasingOpOperandList for a given Value based on
// getAliasingValues.
AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
    Value value, const AnalysisState &state) {
  Operation *op = getOwnerOfValue(value);
  SmallVector<AliasingOpOperand> result;
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (!llvm::isa<TensorType>(opOperand.get().getType()))
      continue;
    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    for (const auto &it : aliasingValues)
      if (it.value == value)
        result.emplace_back(&opOperand, it.relation, it.isDefinite);
  }
  return AliasingOpOperandList(std::move(result));
}

FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
    Value value, const BufferizationOptions &options,
    SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

  // No further analysis is possible for a block argument.
  if (llvm::isa<BlockArgument>(value))
    return bufferization::getMemRefType(value, options);

  // Value is an OpResult.
  Operation *op = getOwnerOfValue(value);
  auto opResult = llvm::cast<OpResult>(value);
  AnalysisState state(options);
  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
  if (aliases.getNumAliases() > 0 &&
      aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
    // If the OpResult has an equivalent OpOperand, both OpResult and
    // OpOperand bufferize to the exact same buffer type.
    Value equivalentOperand = aliases.getAliases().front().opOperand->get();
    return getBufferType(equivalentOperand, options, invocationStack);
  }

  // If we do not know the memory space and there is no default memory space,
  // report a failure.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::detail::defaultIsRepetitiveRegion(
    BufferizableOpInterface bufferizableOp, unsigned index) {
  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
  auto regionInterface =
      dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
  if (!regionInterface)
    return false;
  return regionInterface.isRepetitiveRegion(index);
}

AliasingOpOperandList
bufferization::detail::unknownGetAliasingOpOperands(Value value) {
  // TODO: Take into account successor blocks.
  // No aliasing in case of non-entry blocks.
  if (auto bbArg = dyn_cast<BlockArgument>(value))
    if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())
      return {};

  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingOpOperandList r;
  for (OpOperand &operand : getOwnerOfValue(value)->getOpOperands())
    if (isa<TensorType>(operand.get().getType()))
      r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

AliasingValueList
bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
  // TODO: Take into account successor blocks.
  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingValueList r;
  for (OpResult result : opOperand.getOwner()->getOpResults())
    if (llvm::isa<TensorType>(result.getType()))
      r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
  for (Region &region : opOperand.getOwner()->getRegions())
    if (!region.getBlocks().empty())
      for (BlockArgument bbArg : region.getBlocks().front().getArguments())
        if (isa<TensorType>(bbArg.getType()))
          r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
  bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
    return any_of(r.getBlocks(), [&](Block &b) {
      return any_of(b.getArguments(), [&](BlockArgument bbArg) {
        return isaTensor(bbArg.getType());
      });
    });
  });
  if (hasTensorBlockArgument)
    return true;

  if (any_of(op->getResultTypes(), isaTensor))
    return true;
  return any_of(op->getOperandTypes(), isaTensor);
}