MLIR  19.0.0git
Bufferize.cpp
Go to the documentation of this file.
1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Operation.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Transforms/Passes.h"
25 #include <optional>
26 
27 namespace mlir {
28 namespace bufferization {
29 #define GEN_PASS_DEF_FINALIZINGBUFFERIZE
30 #define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
31 #define GEN_PASS_DEF_ONESHOTBUFFERIZE
32 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
33 } // namespace bufferization
34 } // namespace mlir
35 
36 #define DEBUG_TYPE "bufferize"
37 
38 using namespace mlir;
39 using namespace mlir::bufferization;
40 
41 //===----------------------------------------------------------------------===//
42 // BufferizeTypeConverter
43 //===----------------------------------------------------------------------===//
44 
46  ValueRange inputs, Location loc) {
47  assert(inputs.size() == 1);
48  assert(isa<BaseMemRefType>(inputs[0].getType()));
49  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
50 }
51 
52 /// Registers conversions into BufferizeTypeConverter
54  // Keep all types unchanged.
55  addConversion([](Type type) { return type; });
56  // Convert RankedTensorType to MemRefType.
57  addConversion([](RankedTensorType type) -> Type {
58  return MemRefType::get(type.getShape(), type.getElementType());
59  });
60  // Convert UnrankedTensorType to UnrankedMemRefType.
61  addConversion([](UnrankedTensorType type) -> Type {
62  return UnrankedMemRefType::get(type.getElementType(), 0);
63  });
67  ValueRange inputs, Location loc) -> Value {
68  assert(inputs.size() == 1 && "expected exactly one input");
69 
70  if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
71  // MemRef to MemRef cast.
72  assert(inputType != type && "expected different types");
73  // Unranked to ranked and ranked to unranked casts must be explicit.
74  auto rankedDestType = dyn_cast<MemRefType>(type);
75  if (!rankedDestType)
76  return nullptr;
78  options.bufferAlignment = 0;
79  FailureOr<Value> replacement =
80  castOrReallocMemRefValue(builder, inputs[0], rankedDestType, options);
81  if (failed(replacement))
82  return nullptr;
83  return *replacement;
84  }
85 
86  if (isa<TensorType>(inputs[0].getType())) {
87  // Tensor to MemRef cast.
88  return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
89  }
90 
91  llvm_unreachable("only tensor/memref input types supported");
92  });
93 }
94 
96  ConversionTarget &target) {
97  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
98 }
99 
100 namespace {
101 // In a finalizing bufferize conversion, we know that all tensors have been
102 // converted to memrefs, thus, this op becomes an identity.
103 class BufferizeToTensorOp
104  : public OpConversionPattern<bufferization::ToTensorOp> {
105 public:
108  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
109  ConversionPatternRewriter &rewriter) const override {
110  rewriter.replaceOp(op, adaptor.getMemref());
111  return success();
112  }
113 };
114 } // namespace
115 
116 namespace {
117 // In a finalizing bufferize conversion, we know that all tensors have been
118 // converted to memrefs, thus, this op becomes an identity.
119 class BufferizeToMemrefOp
120  : public OpConversionPattern<bufferization::ToMemrefOp> {
121 public:
124  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
125  ConversionPatternRewriter &rewriter) const override {
126  rewriter.replaceOp(op, adaptor.getTensor());
127  return success();
128  }
129 };
130 } // namespace
131 
133  BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
134  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
135  patterns.getContext());
136 }
137 
138 namespace {
139 struct FinalizingBufferizePass
140  : public bufferization::impl::FinalizingBufferizeBase<
141  FinalizingBufferizePass> {
142  using FinalizingBufferizeBase<
143  FinalizingBufferizePass>::FinalizingBufferizeBase;
144 
145  void runOnOperation() override {
146  auto func = getOperation();
147  auto *context = &getContext();
148 
149  BufferizeTypeConverter typeConverter;
150  RewritePatternSet patterns(context);
151  ConversionTarget target(*context);
152 
154 
155  // If all result types are legal, and all block arguments are legal (ensured
156  // by func conversion above), then all types in the program are legal.
157  //
158  // We also check that the operand types are legal to avoid creating invalid
159  // IR. For example, this prevents
160  // populateEliminateBufferizeMaterializationsPatterns from updating the
161  // types of the operands to a return op without updating the enclosing
162  // function.
163  target.markUnknownOpDynamicallyLegal(
164  [&](Operation *op) { return typeConverter.isLegal(op); });
165 
166  if (failed(applyFullConversion(func, target, std::move(patterns))))
167  signalPassFailure();
168  }
169 };
170 
171 static LayoutMapOption parseLayoutMapOption(const std::string &s) {
172  if (s == "fully-dynamic-layout-map")
173  return LayoutMapOption::FullyDynamicLayoutMap;
174  if (s == "identity-layout-map")
175  return LayoutMapOption::IdentityLayoutMap;
176  if (s == "infer-layout-map")
177  return LayoutMapOption::InferLayoutMap;
178  llvm_unreachable("invalid layout map option");
179 }
180 
182 parseHeuristicOption(const std::string &s) {
183  if (s == "bottom-up")
185  if (s == "top-down")
187  if (s == "bottom-up-from-terminators")
190  if (s == "fuzzer")
192  llvm_unreachable("invalid analysisheuristic option");
193 }
194 
195 struct OneShotBufferizePass
196  : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
197  OneShotBufferizePass() = default;
198 
199  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
200  : options(options) {}
201 
202  void getDependentDialects(DialectRegistry &registry) const override {
203  registry
204  .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
205  }
206 
207  void runOnOperation() override {
209  if (!options) {
210  // Make new bufferization options if none were provided when creating the
211  // pass.
212  opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
213  opt.allowUnknownOps = allowUnknownOps;
214  opt.analysisFuzzerSeed = analysisFuzzerSeed;
215  opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
216  opt.copyBeforeWrite = copyBeforeWrite;
217  opt.dumpAliasSets = dumpAliasSets;
219  parseLayoutMapOption(functionBoundaryTypeConversion));
220  if (mustInferMemorySpace) {
222  [](TensorType t) -> std::optional<Attribute> {
223  return std::nullopt;
224  };
225  }
226  opt.printConflicts = printConflicts;
227  opt.testAnalysisOnly = testAnalysisOnly;
228  opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
229  opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
230 
231  // Configure type converter.
232  LayoutMapOption unknownTypeConversionOption =
233  parseLayoutMapOption(unknownTypeConversion);
234  if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
236  "Invalid option: 'infer-layout-map' is not a valid value for "
237  "'unknown-type-conversion'");
238  return signalPassFailure();
239  }
240  opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
241  const BufferizationOptions &options) {
242  auto tensorType = cast<TensorType>(value.getType());
243  if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
245  tensorType, memorySpace);
246  assert(unknownTypeConversionOption ==
247  LayoutMapOption::FullyDynamicLayoutMap &&
248  "invalid layout map option");
250  memorySpace);
251  };
252 
253  // Configure op filter.
254  OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
255  // Filter may be specified via options.
256  if (this->dialectFilter.hasValue())
257  return llvm::is_contained(this->dialectFilter,
258  op->getDialect()->getNamespace());
259  // No filter specified: All other ops are allowed.
260  return true;
261  };
262  opt.opFilter.allowOperation(filterFn);
263  } else {
264  opt = *options;
265  }
266 
267  if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
268  // These two flags do not make sense together: "copy-before-write"
269  // indicates that copies should be inserted before every memory write,
270  // but "test-analysis-only" indicates that only the analysis should be
271  // tested. (I.e., no IR is bufferized.)
273  "Invalid option: 'copy-before-write' cannot be used with "
274  "'test-analysis-only'");
275  return signalPassFailure();
276  }
277 
278  if (opt.printConflicts && !opt.testAnalysisOnly) {
279  emitError(
281  "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
282  return signalPassFailure();
283  }
284 
285  if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
286  emitError(
288  "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
289  return signalPassFailure();
290  }
291 
292  BufferizationStatistics statistics;
293  ModuleOp moduleOp = getOperation();
294  if (opt.bufferizeFunctionBoundaries) {
295  if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
296  signalPassFailure();
297  return;
298  }
299  } else {
300  if (!opt.noAnalysisFuncFilter.empty()) {
302  "Invalid option: 'no-analysis-func-filter' requires "
303  "'bufferize-function-boundaries'");
304  return signalPassFailure();
305  }
306  if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
307  signalPassFailure();
308  return;
309  }
310  }
311 
312  // Set pass statistics.
313  this->numBufferAlloc = statistics.numBufferAlloc;
314  this->numTensorInPlace = statistics.numTensorInPlace;
315  this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
316  }
317 
318 private:
319  std::optional<OneShotBufferizationOptions> options;
320 };
321 } // namespace
322 
323 namespace {
324 struct BufferizationBufferizePass
325  : public bufferization::impl::BufferizationBufferizeBase<
326  BufferizationBufferizePass> {
327  void runOnOperation() override {
329  options.opFilter.allowDialect<BufferizationDialect>();
330 
331  if (failed(bufferizeOp(getOperation(), options)))
332  signalPassFailure();
333  }
334 
335  void getDependentDialects(DialectRegistry &registry) const override {
336  registry
337  .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
338  }
339 };
340 } // namespace
341 
343  return std::make_unique<BufferizationBufferizePass>();
344 }
345 
347  return std::make_unique<OneShotBufferizePass>();
348 }
349 
352  return std::make_unique<OneShotBufferizePass>(options);
353 }
354 
355 std::unique_ptr<OperationPass<func::FuncOp>>
357  return std::make_unique<FinalizingBufferizePass>();
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // BufferizableOpInterface-based Bufferization
362 //===----------------------------------------------------------------------===//
363 
364 namespace {
365 /// A rewriter that keeps track of extra information during bufferization.
366 class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
367 public:
368  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
369  DenseSet<Operation *> &toMemrefOps,
370  SmallVector<Operation *> &worklist,
372  BufferizationStatistics *statistics)
373  : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
374  worklist(worklist), analysisState(options), statistics(statistics) {
375  setListener(this);
376  }
377 
378 protected:
379  void notifyOperationErased(Operation *op) override {
380  erasedOps.insert(op);
381  // Erase if present.
382  toMemrefOps.erase(op);
383  }
384 
385  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
386  // We only care about newly created ops.
387  if (previous.isSet())
388  return;
389 
390  erasedOps.erase(op);
391 
392  // Gather statistics about allocs.
393  if (statistics) {
394  if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
395  statistics->numBufferAlloc += static_cast<int64_t>(
396  sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
397  }
398 
399  // Keep track of to_memref ops.
400  if (isa<ToMemrefOp>(op)) {
401  toMemrefOps.insert(op);
402  return;
403  }
404 
405  // Skip to_tensor ops.
406  if (isa<ToTensorOp>(op))
407  return;
408 
409  // Skip non-tensor ops.
410  if (!hasTensorSemantics(op))
411  return;
412 
413  // Skip ops that are not allowed to be bufferized.
414  auto const &options = analysisState.getOptions();
415  if (!options.isOpAllowed(op))
416  return;
417 
418  // Add op to worklist.
419  worklist.push_back(op);
420  }
421 
422 private:
423  /// A set of all erased ops.
424  DenseSet<Operation *> &erasedOps;
425 
426  /// A set of all to_memref ops.
427  DenseSet<Operation *> &toMemrefOps;
428 
429  /// The worklist of ops to be bufferized.
430  SmallVector<Operation *> &worklist;
431 
432  /// The analysis state. Used for debug assertions and access to the
433  /// bufferization options.
434  const AnalysisState analysisState;
435 
436  /// Bufferization statistics for debugging.
437  BufferizationStatistics *statistics;
438 };
439 } // namespace
440 
443  BufferizationStatistics *statistics) {
444  if (options.copyBeforeWrite) {
445  AnalysisState state(options);
446  if (failed(insertTensorCopies(op, state)))
447  return failure();
448  }
449 
450  // Keep track of to_memref ops.
451  DenseSet<Operation *> toMemrefOps;
452  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
453 
454  // Gather all bufferizable ops in top-to-bottom order.
455  //
456  // We should ideally know the exact memref type of all operands when
457  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
458  // Otherwise, we have to use a memref type with a fully dynamic layout map to
459  // avoid copies. We are currently missing patterns for layout maps to
460  // canonicalize away (or canonicalize to more precise layouts).
461  SmallVector<Operation *> worklist;
462  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
463  if (options.isOpAllowed(op) && hasTensorSemantics(op))
464  worklist.push_back(op);
465  });
466 
467  // Keep track of all erased ops.
468  DenseSet<Operation *> erasedOps;
469 
470  // Bufferize all ops.
471  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
472  worklist, options, statistics);
473  for (unsigned i = 0; i < worklist.size(); ++i) {
474  Operation *nextOp = worklist[i];
475  // Skip ops that were erased.
476  if (erasedOps.contains(nextOp))
477  continue;
478  // Skip ops that are not bufferizable or not allowed.
479  auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
480  if (!bufferizableOp)
481  continue;
482  // Skip ops that no longer have tensor semantics.
483  if (!hasTensorSemantics(nextOp))
484  continue;
485  // Check for unsupported unstructured control flow.
486  if (!bufferizableOp.supportsUnstructuredControlFlow())
487  for (Region &r : nextOp->getRegions())
488  if (r.getBlocks().size() > 1)
489  return nextOp->emitOpError(
490  "op or BufferizableOpInterface implementation does not support "
491  "unstructured control flow, but at least one region has multiple "
492  "blocks");
493 
494  // Bufferize the op.
495  LLVM_DEBUG(llvm::dbgs()
496  << "//===-------------------------------------------===//\n"
497  << "IR after bufferizing: " << nextOp->getName() << "\n");
498  rewriter.setInsertionPoint(nextOp);
499  if (failed(bufferizableOp.bufferize(rewriter, options))) {
500  LLVM_DEBUG(llvm::dbgs()
501  << "failed to bufferize\n"
502  << "//===-------------------------------------------===//\n");
503  return nextOp->emitError("failed to bufferize op");
504  }
505  LLVM_DEBUG(llvm::dbgs()
506  << *op
507  << "\n//===-------------------------------------------===//\n");
508  }
509 
510  // Return early if the top-level op is entirely gone.
511  if (erasedOps.contains(op))
512  return success();
513 
514  // Fold all to_memref(to_tensor(x)) pairs.
515  for (Operation *op : toMemrefOps) {
516  rewriter.setInsertionPoint(op);
518  rewriter, cast<ToMemrefOp>(op), options);
519  }
520 
521  // Remove all dead to_tensor ops.
522  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
523  if (toTensorOp->getUses().empty()) {
524  rewriter.eraseOp(toTensorOp);
525  return WalkResult::skip();
526  }
527  return WalkResult::advance();
528  });
529 
530  /// Check the result of bufferization. Return an error if an op was not
531  /// bufferized, unless partial bufferization is allowed.
532  if (options.allowUnknownOps)
533  return success();
534 
535  for (Operation *op : worklist) {
536  // Skip ops that are entirely gone.
537  if (erasedOps.contains(op))
538  continue;
539  // Ops that no longer have tensor semantics (because they were updated
540  // in-place) are allowed.
541  if (!hasTensorSemantics(op))
542  continue;
543  // Continue ops that are not allowed.
544  if (!options.isOpAllowed(op))
545  continue;
546  // Ops without any uses and no side effects will fold away.
547  if (op->getUses().empty() && isMemoryEffectFree(op))
548  continue;
549  // ToTensorOps/ToMemrefOps are allowed in the output.
550  if (isa<ToTensorOp, ToMemrefOp>(op))
551  continue;
552  return op->emitError("op was not bufferized");
553  }
554 
555  return success();
556 }
557 
560  const BufferizationOptions &options) {
561  OpBuilder::InsertionGuard g(rewriter);
562  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
563  if (!bufferizableOp)
564  return failure();
565 
566  // Compute the new signature.
567  SmallVector<Type> newTypes;
568  for (BlockArgument &bbArg : block->getArguments()) {
569  auto tensorType = dyn_cast<TensorType>(bbArg.getType());
570  if (!tensorType) {
571  newTypes.push_back(bbArg.getType());
572  continue;
573  }
574 
575  FailureOr<BaseMemRefType> memrefType =
577  if (failed(memrefType))
578  return failure();
579  newTypes.push_back(*memrefType);
580  }
581 
582  // Change the type of all block arguments.
583  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
584  if (bbArg.getType() == type)
585  continue;
586 
587  // Collect all uses of the bbArg.
588  SmallVector<OpOperand *> bbArgUses;
589  for (OpOperand &use : bbArg.getUses())
590  bbArgUses.push_back(&use);
591 
592  // Change the bbArg type to memref.
593  bbArg.setType(type);
594 
595  // Replace all uses of the original tensor bbArg.
596  rewriter.setInsertionPointToStart(block);
597  if (!bbArgUses.empty()) {
598  Value toTensorOp =
599  rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
600  for (OpOperand *use : bbArgUses)
601  use->set(toTensorOp);
602  }
603  }
604 
605  // Bufferize callers of the block.
606  for (Operation *op : block->getUsers()) {
607  auto branchOp = dyn_cast<BranchOpInterface>(op);
608  if (!branchOp)
609  return op->emitOpError("cannot bufferize ops with block references that "
610  "do not implement BranchOpInterface");
611 
612  auto it = llvm::find(op->getSuccessors(), block);
613  assert(it != op->getSuccessors().end() && "could find successor");
614  int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);
615 
616  SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
617  SmallVector<Value> newOperands;
618  for (auto [operand, type] :
619  llvm::zip(operands.getForwardedOperands(), newTypes)) {
620  if (operand.getType() == type) {
621  // Not a tensor type. Nothing to do for this operand.
622  newOperands.push_back(operand);
623  continue;
624  }
625  FailureOr<BaseMemRefType> operandBufferType =
627  if (failed(operandBufferType))
628  return failure();
629  rewriter.setInsertionPointAfterValue(operand);
630  Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
631  operand.getLoc(), *operandBufferType, operand);
632  // A cast is needed if the operand and the block argument have different
633  // bufferized types.
634  if (type != *operandBufferType)
635  bufferizedOperand = rewriter.create<memref::CastOp>(
636  operand.getLoc(), type, bufferizedOperand);
637  newOperands.push_back(bufferizedOperand);
638  }
639  operands.getMutableForwardedOperands().assign(newOperands);
640  }
641 
642  return success();
643 }
644 
647  options.allowUnknownOps = true;
648  options.copyBeforeWrite = true;
649  options.enforceAliasingInvariants = false;
650  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
651  const BufferizationOptions &options) {
653  cast<TensorType>(value.getType()), memorySpace);
654  };
655  options.opFilter.allowDialect<BufferizationDialect>();
656  return options;
657 }
static Value materializeToTensor(OpBuilder &builder, TensorType type, ValueRange inputs, Location loc)
Definition: Bufferize.cpp:45
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:84
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
StringRef getNamespace() const
Definition: Dialect.h:57
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
SuccessorRange getSuccessors()
Definition: Operation.h:699
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:842
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:91
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal, or source, type to a legal type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
AnalysisState provides a variety of helper functions for dealing with tensor values.
A helper type converter class that automatically populates the relevant materializations and type conversions for bufferization.
Definition: Bufferize.h:43
BufferizeTypeConverter()
Registers conversions into BufferizeTypeConverter.
Definition: Bufferize.cpp:53
void allowOperation()
Allow the given ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
void populateEliminateBufferizeMaterializationsPatterns(BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns)
Populate patterns to eliminate bufferize materializations.
Definition: Bufferize.cpp:132
LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
void populateBufferizeMaterializationLegality(ConversionTarget &target)
Marks ops used by bufferization for type conversion materializations as "legal" in the given ConversionTarget.
Definition: Bufferize.cpp:95
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
Definition: Bufferize.cpp:441
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
std::unique_ptr< OperationPass< func::FuncOp > > createFinalizingBufferizePass()
Creates a pass that finalizes a partial bufferization by removing remaining bufferization.to_tensor and bufferization.to_memref operations.
Definition: Bufferize.cpp:356
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor).
Definition: Bufferize.cpp:559
std::unique_ptr< Pass > createOneShotBufferizePass()
Create a pass that bufferizes all ops that implement BufferizableOpInterface with One-Shot Bufferize.
Definition: Bufferize.cpp:346
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, ToMemrefOp toMemref, const BufferizationOptions &options)
Try to fold to_memref(to_tensor(x)).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
BufferizationOptions getPartialBufferizationOptions()
Return BufferizationOptions such that the bufferizeOp behaves like the old (deprecated) partial, dialect conversion-based bufferization passes.
Definition: Bufferize.cpp:645
std::unique_ptr< Pass > createBufferizationBufferizePass()
Create a pass that bufferizes ops from the bufferization dialect.
Definition: Bufferize.cpp:342
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
The following effect indicates that the operation allocates from some resource.
Options for BufferizableOpInterface-based bufferization.
bool copyBeforeWrite
If set to true, the analysis is skipped.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption)
This function controls buffer types on function signatures.
bool allowUnknownOps
Specifies whether non-bufferizable ops are allowed in the input.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the analysis).
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
UnknownTypeConverterFn unknownTypeConverterFn
Type converter from tensors to memrefs.
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not.
Bufferization statistics for debugging.
Definition: Bufferize.h:34
Options for analysis-enabled bufferization.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.