MLIR  18.0.0git
Bufferize.cpp
Go to the documentation of this file.
1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Operation.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Transforms/Passes.h"
25 #include <optional>
26 
27 namespace mlir {
28 namespace bufferization {
29 #define GEN_PASS_DEF_FINALIZINGBUFFERIZE
30 #define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
31 #define GEN_PASS_DEF_ONESHOTBUFFERIZE
32 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
33 } // namespace bufferization
34 } // namespace mlir
35 
36 #define DEBUG_TYPE "bufferize"
37 
38 using namespace mlir;
39 using namespace mlir::bufferization;
40 
41 //===----------------------------------------------------------------------===//
42 // BufferizeTypeConverter
43 //===----------------------------------------------------------------------===//
44 
46  ValueRange inputs, Location loc) {
47  assert(inputs.size() == 1);
48  assert(isa<BaseMemRefType>(inputs[0].getType()));
49  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
50 }
51 
52 /// Registers conversions into BufferizeTypeConverter
54  // Keep all types unchanged.
55  addConversion([](Type type) { return type; });
56  // Convert RankedTensorType to MemRefType.
57  addConversion([](RankedTensorType type) -> Type {
58  return MemRefType::get(type.getShape(), type.getElementType());
59  });
60  // Convert UnrankedTensorType to UnrankedMemRefType.
61  addConversion([](UnrankedTensorType type) -> Type {
62  return UnrankedMemRefType::get(type.getElementType(), 0);
63  });
67  ValueRange inputs, Location loc) -> Value {
68  assert(inputs.size() == 1 && "expected exactly one input");
69 
70  if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
71  // MemRef to MemRef cast.
72  assert(inputType != type && "expected different types");
73  // Unranked to ranked and ranked to unranked casts must be explicit.
74  auto rankedDestType = dyn_cast<MemRefType>(type);
75  if (!rankedDestType)
76  return nullptr;
77  FailureOr<Value> replacement =
78  castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
79  if (failed(replacement))
80  return nullptr;
81  return *replacement;
82  }
83 
84  if (isa<TensorType>(inputs[0].getType())) {
85  // Tensor to MemRef cast.
86  return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
87  }
88 
89  llvm_unreachable("only tensor/memref input types supported");
90  });
91 }
92 
94  ConversionTarget &target) {
95  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
96 }
97 
98 namespace {
99 // In a finalizing bufferize conversion, we know that all tensors have been
100 // converted to memrefs, thus, this op becomes an identity.
101 class BufferizeToTensorOp
102  : public OpConversionPattern<bufferization::ToTensorOp> {
103 public:
106  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
107  ConversionPatternRewriter &rewriter) const override {
108  rewriter.replaceOp(op, adaptor.getMemref());
109  return success();
110  }
111 };
112 } // namespace
113 
114 namespace {
115 // In a finalizing bufferize conversion, we know that all tensors have been
116 // converted to memrefs, thus, this op becomes an identity.
117 class BufferizeToMemrefOp
118  : public OpConversionPattern<bufferization::ToMemrefOp> {
119 public:
122  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
123  ConversionPatternRewriter &rewriter) const override {
124  rewriter.replaceOp(op, adaptor.getTensor());
125  return success();
126  }
127 };
128 } // namespace
129 
131  BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
132  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
133  patterns.getContext());
134 }
135 
136 namespace {
137 struct FinalizingBufferizePass
138  : public bufferization::impl::FinalizingBufferizeBase<
139  FinalizingBufferizePass> {
140  using FinalizingBufferizeBase<
141  FinalizingBufferizePass>::FinalizingBufferizeBase;
142 
143  void runOnOperation() override {
144  auto func = getOperation();
145  auto *context = &getContext();
146 
147  BufferizeTypeConverter typeConverter;
148  RewritePatternSet patterns(context);
149  ConversionTarget target(*context);
150 
152 
153  // If all result types are legal, and all block arguments are legal (ensured
154  // by func conversion above), then all types in the program are legal.
155  //
156  // We also check that the operand types are legal to avoid creating invalid
157  // IR. For example, this prevents
158  // populateEliminateBufferizeMaterializationsPatterns from updating the
159  // types of the operands to a return op without updating the enclosing
160  // function.
161  target.markUnknownOpDynamicallyLegal(
162  [&](Operation *op) { return typeConverter.isLegal(op); });
163 
164  if (failed(applyFullConversion(func, target, std::move(patterns))))
165  signalPassFailure();
166  }
167 };
168 
169 static LayoutMapOption parseLayoutMapOption(const std::string &s) {
170  if (s == "fully-dynamic-layout-map")
171  return LayoutMapOption::FullyDynamicLayoutMap;
172  if (s == "identity-layout-map")
173  return LayoutMapOption::IdentityLayoutMap;
174  if (s == "infer-layout-map")
175  return LayoutMapOption::InferLayoutMap;
176  llvm_unreachable("invalid layout map option");
177 }
178 
180 parseHeuristicOption(const std::string &s) {
181  if (s == "bottom-up")
183  if (s == "top-down")
185  llvm_unreachable("invalid analysisheuristic option");
186 }
187 
188 struct OneShotBufferizePass
189  : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
190  OneShotBufferizePass() = default;
191 
192  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
193  : options(options) {}
194 
195  void getDependentDialects(DialectRegistry &registry) const override {
196  registry
197  .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
198  }
199 
200  void runOnOperation() override {
202  if (!options) {
203  // Make new bufferization options if none were provided when creating the
204  // pass.
205  opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
206  opt.allowUnknownOps = allowUnknownOps;
207  opt.analysisFuzzerSeed = analysisFuzzerSeed;
208  opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
209  opt.copyBeforeWrite = copyBeforeWrite;
210  opt.dumpAliasSets = dumpAliasSets;
212  parseLayoutMapOption(functionBoundaryTypeConversion));
213  if (mustInferMemorySpace)
214  opt.defaultMemorySpace = std::nullopt;
215  opt.printConflicts = printConflicts;
216  opt.testAnalysisOnly = testAnalysisOnly;
217  opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
218  opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
219 
220  // Configure type converter.
221  LayoutMapOption unknownTypeConversionOption =
222  parseLayoutMapOption(unknownTypeConversion);
223  if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
225  "Invalid option: 'infer-layout-map' is not a valid value for "
226  "'unknown-type-conversion'");
227  return signalPassFailure();
228  }
229  opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
230  const BufferizationOptions &options) {
231  auto tensorType = cast<TensorType>(value.getType());
232  if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
234  tensorType, memorySpace);
235  assert(unknownTypeConversionOption ==
236  LayoutMapOption::FullyDynamicLayoutMap &&
237  "invalid layout map option");
239  memorySpace);
240  };
241 
242  // Configure op filter.
243  OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
244  // Filter may be specified via options.
245  if (this->dialectFilter.hasValue())
246  return llvm::is_contained(this->dialectFilter,
247  op->getDialect()->getNamespace());
248  // No filter specified: All other ops are allowed.
249  return true;
250  };
251  opt.opFilter.allowOperation(filterFn);
252  } else {
253  opt = *options;
254  }
255 
256  if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
257  // These two flags do not make sense together: "copy-before-write"
258  // indicates that copies should be inserted before every memory write,
259  // but "test-analysis-only" indicates that only the analysis should be
260  // tested. (I.e., no IR is bufferized.)
262  "Invalid option: 'copy-before-write' cannot be used with "
263  "'test-analysis-only'");
264  return signalPassFailure();
265  }
266 
267  if (opt.printConflicts && !opt.testAnalysisOnly) {
268  emitError(
270  "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
271  return signalPassFailure();
272  }
273 
274  if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
275  emitError(
277  "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
278  return signalPassFailure();
279  }
280 
281  BufferizationStatistics statistics;
282  ModuleOp moduleOp = getOperation();
283  if (opt.bufferizeFunctionBoundaries) {
284  if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
285  signalPassFailure();
286  return;
287  }
288  } else {
289  if (!opt.noAnalysisFuncFilter.empty()) {
291  "Invalid option: 'no-analysis-func-filter' requires "
292  "'bufferize-function-boundaries'");
293  return signalPassFailure();
294  }
295  if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
296  signalPassFailure();
297  return;
298  }
299  }
300 
301  // Set pass statistics.
302  this->numBufferAlloc = statistics.numBufferAlloc;
303  this->numTensorInPlace = statistics.numTensorInPlace;
304  this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
305  }
306 
307 private:
308  std::optional<OneShotBufferizationOptions> options;
309 };
310 } // namespace
311 
312 namespace {
313 struct BufferizationBufferizePass
314  : public bufferization::impl::BufferizationBufferizeBase<
315  BufferizationBufferizePass> {
316  void runOnOperation() override {
318  options.opFilter.allowDialect<BufferizationDialect>();
319 
320  if (failed(bufferizeOp(getOperation(), options)))
321  signalPassFailure();
322  }
323 
324  void getDependentDialects(DialectRegistry &registry) const override {
325  registry
326  .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
327  }
328 };
329 } // namespace
330 
332  return std::make_unique<BufferizationBufferizePass>();
333 }
334 
336  return std::make_unique<OneShotBufferizePass>();
337 }
338 
341  return std::make_unique<OneShotBufferizePass>(options);
342 }
343 
344 std::unique_ptr<OperationPass<func::FuncOp>>
346  return std::make_unique<FinalizingBufferizePass>();
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // BufferizableOpInterface-based Bufferization
351 //===----------------------------------------------------------------------===//
352 
353 static bool isaTensor(Type t) { return isa<TensorType>(t); }
354 
355 /// Return true if the given op has a tensor result or a tensor operand.
356 static bool hasTensorSemantics(Operation *op) {
357  bool hasTensorBlockArgument = any_of(op->getRegions(), [](Region &r) {
358  return any_of(r.getBlocks(), [](Block &b) {
359  return any_of(b.getArguments(), [](BlockArgument bbArg) {
360  return isaTensor(bbArg.getType());
361  });
362  });
363  });
364  if (hasTensorBlockArgument)
365  return true;
366 
367  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
368  bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
369  bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
370  return hasTensorArg || hasTensorResult;
371  }
372 
373  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
374  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
375  return hasTensorResult || hasTensorOperand;
376 }
377 
378 namespace {
379 /// A rewriter that keeps track of extra information during bufferization.
380 class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
381 public:
382  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
383  DenseSet<Operation *> &toMemrefOps,
384  SmallVector<Operation *> &worklist,
386  BufferizationStatistics *statistics)
387  : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
388  worklist(worklist), analysisState(options), statistics(statistics) {
389  setListener(this);
390  }
391 
392 protected:
393  void notifyOperationRemoved(Operation *op) override {
394  erasedOps.insert(op);
395  // Erase if present.
396  toMemrefOps.erase(op);
397  }
398 
399  void notifyOperationInserted(Operation *op) override {
400  erasedOps.erase(op);
401 
402  // Gather statistics about allocs.
403  if (statistics) {
404  if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
405  statistics->numBufferAlloc += static_cast<int64_t>(
406  sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
407  }
408 
409  // Keep track of to_memref ops.
410  if (isa<ToMemrefOp>(op)) {
411  toMemrefOps.insert(op);
412  return;
413  }
414 
415  // Skip to_tensor ops.
416  if (isa<ToTensorOp>(op))
417  return;
418 
419  // Skip non-tensor ops.
420  if (!hasTensorSemantics(op))
421  return;
422 
423  // Skip ops that are not allowed to be bufferized.
424  auto const &options = analysisState.getOptions();
425  if (!options.isOpAllowed(op))
426  return;
427 
428  // Add op to worklist.
429  worklist.push_back(op);
430  }
431 
432 private:
433  /// A set of all erased ops.
434  DenseSet<Operation *> &erasedOps;
435 
436  /// A set of all to_memref ops.
437  DenseSet<Operation *> &toMemrefOps;
438 
439  /// The worklist of ops to be bufferized.
440  SmallVector<Operation *> &worklist;
441 
442  /// The analysis state. Used for debug assertions and access to the
443  /// bufferization options.
444  const AnalysisState analysisState;
445 
446  /// Bufferization statistics for debugging.
447  BufferizationStatistics *statistics;
448 };
449 } // namespace
450 
453  BufferizationStatistics *statistics) {
454  if (options.copyBeforeWrite) {
455  AnalysisState state(options);
456  if (failed(insertTensorCopies(op, state)))
457  return failure();
458  }
459 
460  // Keep track of to_memref ops.
461  DenseSet<Operation *> toMemrefOps;
462  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
463 
464  // Gather all bufferizable ops in top-to-bottom order.
465  //
466  // We should ideally know the exact memref type of all operands when
467  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
468  // Otherwise, we have to use a memref type with a fully dynamic layout map to
469  // avoid copies. We are currently missing patterns for layout maps to
470  // canonicalize away (or canonicalize to more precise layouts).
471  SmallVector<Operation *> worklist;
472  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
473  if (hasTensorSemantics(op))
474  worklist.push_back(op);
475  });
476 
477  // Keep track of all erased ops.
478  DenseSet<Operation *> erasedOps;
479 
480  // Bufferize all ops.
481  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
482  worklist, options, statistics);
483  for (unsigned i = 0; i < worklist.size(); ++i) {
484  Operation *nextOp = worklist[i];
485  // Skip ops that were erased.
486  if (erasedOps.contains(nextOp))
487  continue;
488  // Skip ops that are not bufferizable or not allowed.
489  auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
490  if (!bufferizableOp)
491  continue;
492  if (!options.isOpAllowed(nextOp))
493  continue;
494  // Skip ops that no longer have tensor semantics.
495  if (!hasTensorSemantics(nextOp))
496  continue;
497  // Check for unsupported unstructured control flow.
498  if (!bufferizableOp.supportsUnstructuredControlFlow())
499  for (Region &r : nextOp->getRegions())
500  if (r.getBlocks().size() > 1)
501  return nextOp->emitOpError(
502  "op or BufferizableOpInterface implementation does not support "
503  "unstructured control flow, but at least one region has multiple "
504  "blocks");
505 
506  // Bufferize the op.
507  LLVM_DEBUG(llvm::dbgs()
508  << "//===-------------------------------------------===//\n"
509  << "IR after bufferizing: " << nextOp->getName() << "\n");
510  rewriter.setInsertionPoint(nextOp);
511  if (failed(bufferizableOp.bufferize(rewriter, options))) {
512  LLVM_DEBUG(llvm::dbgs()
513  << "failed to bufferize\n"
514  << "//===-------------------------------------------===//\n");
515  return nextOp->emitError("failed to bufferize op");
516  }
517  LLVM_DEBUG(llvm::dbgs()
518  << *op
519  << "\n//===-------------------------------------------===//\n");
520  }
521 
522  // Fold all to_memref(to_tensor(x)) pairs.
523  for (Operation *op : toMemrefOps) {
524  rewriter.setInsertionPoint(op);
526  cast<ToMemrefOp>(op));
527  }
528 
529  // Remove all dead to_tensor ops.
530  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
531  if (toTensorOp->getUses().empty()) {
532  rewriter.eraseOp(toTensorOp);
533  return WalkResult::skip();
534  }
535  return WalkResult::advance();
536  });
537 
538  /// Check the result of bufferization. Return an error if an op was not
539  /// bufferized, unless partial bufferization is allowed.
540  if (options.allowUnknownOps)
541  return success();
542 
543  for (Operation *op : worklist) {
544  // Skip ops that are entirely gone.
545  if (erasedOps.contains(op))
546  continue;
547  // Ops that no longer have tensor semantics (because they were updated
548  // in-place) are allowed.
549  if (!hasTensorSemantics(op))
550  continue;
551  // Continue ops that are not allowed.
552  if (!options.isOpAllowed(op))
553  continue;
554  // Ops without any uses and no side effects will fold away.
555  if (op->getUses().empty() && isMemoryEffectFree(op))
556  continue;
557  // ToTensorOps/ToMemrefOps are allowed in the output.
558  if (isa<ToTensorOp, ToMemrefOp>(op))
559  continue;
560  return op->emitError("op was not bufferized");
561  }
562 
563  return success();
564 }
565 
568  const BufferizationOptions &options) {
569  OpBuilder::InsertionGuard g(rewriter);
570  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
571  if (!bufferizableOp)
572  return failure();
573 
574  // Compute the new signature.
575  SmallVector<Type> newTypes;
576  for (BlockArgument &bbArg : block->getArguments()) {
577  auto tensorType = dyn_cast<TensorType>(bbArg.getType());
578  if (!tensorType) {
579  newTypes.push_back(bbArg.getType());
580  continue;
581  }
582 
583  FailureOr<BaseMemRefType> memrefType =
585  if (failed(memrefType))
586  return failure();
587  newTypes.push_back(*memrefType);
588  }
589 
590  // Change the type of all block arguments.
591  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
592  if (bbArg.getType() == type)
593  continue;
594 
595  // Collect all uses of the bbArg.
596  SmallVector<OpOperand *> bbArgUses;
597  for (OpOperand &use : bbArg.getUses())
598  bbArgUses.push_back(&use);
599 
600  // Change the bbArg type to memref.
601  bbArg.setType(type);
602 
603  // Replace all uses of the original tensor bbArg.
604  rewriter.setInsertionPointToStart(block);
605  if (!bbArgUses.empty()) {
606  Value toTensorOp =
607  rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
608  for (OpOperand *use : bbArgUses)
609  use->set(toTensorOp);
610  }
611  }
612 
613  // Bufferize callers of the block.
614  for (Operation *op : block->getUsers()) {
615  auto branchOp = dyn_cast<BranchOpInterface>(op);
616  if (!branchOp)
617  return op->emitOpError("cannot bufferize ops with block references that "
618  "do not implement BranchOpInterface");
619 
620  auto it = llvm::find(op->getSuccessors(), block);
621  assert(it != op->getSuccessors().end() && "could find successor");
622  int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);
623 
624  SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
625  SmallVector<Value> newOperands;
626  for (auto [operand, type] :
627  llvm::zip(operands.getForwardedOperands(), newTypes)) {
628  if (operand.getType() == type) {
629  // Not a tensor type. Nothing to do for this operand.
630  newOperands.push_back(operand);
631  continue;
632  }
633  FailureOr<BaseMemRefType> operandBufferType =
635  if (failed(operandBufferType))
636  return failure();
637  rewriter.setInsertionPointAfterValue(operand);
638  Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
639  operand.getLoc(), *operandBufferType, operand);
640  // A cast is needed if the operand and the block argument have different
641  // bufferized types.
642  if (type != *operandBufferType)
643  bufferizedOperand = rewriter.create<memref::CastOp>(
644  operand.getLoc(), type, bufferizedOperand);
645  newOperands.push_back(bufferizedOperand);
646  }
647  operands.getMutableForwardedOperands().assign(newOperands);
648  }
649 
650  return success();
651 }
652 
655  options.allowUnknownOps = true;
656  options.copyBeforeWrite = true;
657  options.enforceAliasingInvariants = false;
658  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
659  const BufferizationOptions &options) {
661  cast<TensorType>(value.getType()), memorySpace);
662  };
663  options.opFilter.allowDialect<BufferizationDialect>();
664  return options;
665 }
static Value materializeToTensor(OpBuilder &builder, TensorType type, ValueRange inputs, Location loc)
Definition: Bufferize.cpp:45
static bool isaTensor(Type t)
Definition: Bufferize.cpp:353
static bool hasTensorSemantics(Operation *op)
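Return true if the given op has a tensor-typed region block argument, a tensor in its function signature (for FunctionOpInterface ops), or a tensor result or operand (for all other ops).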
Definition: Bufferize.cpp:356
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:80
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
StringRef getNamespace() const
Definition: Dialect.h:57
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:710
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
SuccessorRange getSuccessors()
Definition: Operation.h:682
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:825
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:91
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
AnalysisState provides a variety of helper functions for dealing with tensor values.
A helper type converter class that automatically populates the relevant materializations and type conversions for bufferization.
Definition: Bufferize.h:43
BufferizeTypeConverter()
Registers conversions into BufferizeTypeConverter.
Definition: Bufferize.cpp:53
void allowOperation()
Allow the given ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
void populateEliminateBufferizeMaterializationsPatterns(BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns)
Populate patterns to eliminate bufferize materializations.
Definition: Bufferize.cpp:130
LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
void populateBufferizeMaterializationLegality(ConversionTarget &target)
Marks ops used by bufferization for type conversion materializations as "legal" in the given ConversionTarget.
Definition: Bufferize.cpp:93
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref)
Try to fold to_memref(to_tensor(x)).
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
Definition: Bufferize.cpp:451
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
std::unique_ptr< OperationPass< func::FuncOp > > createFinalizingBufferizePass()
Creates a pass that finalizes a partial bufferization by removing remaining bufferization.to_tensor and bufferization.to_memref operations.
Definition: Bufferize.cpp:345
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
Definition: Bufferize.cpp:567
std::unique_ptr< Pass > createOneShotBufferizePass()
Create a pass that bufferizes all ops that implement BufferizableOpInterface with One-Shot Bufferize.
Definition: Bufferize.cpp:335
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
BufferizationOptions getPartialBufferizationOptions()
Return BufferizationOptions such that the bufferizeOp behaves like the old (deprecated) partial, dialect conversion-based bufferization passes.
Definition: Bufferize.cpp:653
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
std::unique_ptr< Pass > createBufferizationBufferizePass()
Create a pass that bufferizes ops from the bufferization dialect.
Definition: Bufferize.cpp:331
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
The following effect indicates that the operation allocates from some resource.
Options for BufferizableOpInterface-based bufferization.
bool copyBeforeWrite
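If set to true, the analysis is skipped and a buffer copy is inserted before every memory write. This flag cannot be combined with testAnalysisOnly.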
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption)
This function controls buffer types on function signatures.
bool allowUnknownOps
Specifies whether not bufferizable ops are allowed in the input.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the analysis); no IR is bufferized.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
std::optional< Attribute > defaultMemorySpace
The default memory space that should be used when it cannot be inferred from the context.
UnknownTypeConverterFn unknownTypeConverterFn
Type converter from tensors to memrefs.
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not.
Bufferization statistics for debugging.
Definition: Bufferize.h:34
Options for analysis-enabled bufferization.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.