//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

namespace {

static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysis-heuristic option");
}

struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizePassBase<
          OneShotBufferizePass> {
  using Base::Base;

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);

      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' is allowed in "
            << getArgument();
        return signalPassFailure();
      }

      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }

      if (useEncodingForMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          if (auto rtt = dyn_cast<RankedTensorType>(t))
            return rtt.getEncoding();
          return std::nullopt;
        };
      }

      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = cast<TensorType>(value.getType());
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace
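
// A minimal sketch of driving One-Shot Bufferize programmatically rather than
// through the registered pass (illustrative only; `module` is assumed to be an
// already-loaded, verified ModuleOp):
//
//   OneShotBufferizationOptions opts;
//   opts.bufferizeFunctionBoundaries = true;
//   BufferizationStatistics stats;
//   if (failed(runOneShotModuleBufferize(module, opts, &stats)))
//     return failure();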

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
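  //
  // For illustration (types shown are examples only, not produced by this
  // exact code): an operand whose producer was already bufferized can use the
  // producer's concrete type, e.g. memref<4xf32>, whereas an operand whose
  // producer has not been bufferized yet must conservatively be given a fully
  // dynamic layout such as memref<4xf32, strided<[?], offset: ?>>.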
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, statistics);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has multiple "
              "blocks");

    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)foldToMemrefToTensorPair(rewriter, cast<ToMemrefOp>(op), options);
  }
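  // Illustrative IR for the fold above (op syntax abbreviated): given
  //   %t = bufferization.to_tensor %m ...
  //   %m2 = bufferization.to_memref %t ...
  // uses of %m2 are redirected to %m, with a memref.cast inserted if the two
  // memref types differ.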

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}
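
// A minimal sketch of calling bufferizeOp directly (illustrative only; `op` is
// assumed to be an op whose nested ops implement BufferizableOpInterface).
// bufferizeOp itself runs no analysis, so when it is not preceded by One-Shot
// analysis, copy-before-write is the conservative choice:
//
//   BufferizationOptions opts;
//   opts.copyBeforeWrite = true;
//   opts.allowUnknownOps = true;
//   if (failed(bufferizeOp(op, opts)))
//     return failure();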

LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(bbArg, options);
    if (failed(memrefType))
      return failure();
    newTypes.push_back(*memrefType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
          bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BaseMemRefType> operandBufferType =
          bufferization::getBufferType(operand, options);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}
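
// Illustrative before/after for bufferizeBlockSignature (simplified IR; exact
// op syntax and the resulting layout depend on the BufferizationOptions):
//
//   Before:
//     cf.br ^bb1(%t : tensor<5xf32>)
//   ^bb1(%arg0: tensor<5xf32>):
//     ...
//
//   After:
//     %m = bufferization.to_memref %t ...
//     cf.br ^bb1(%m : memref<5xf32, strided<[?], offset: ?>>)
//   ^bb1(%arg0: memref<5xf32, strided<[?], offset: ?>>):
//     %t2 = bufferization.to_tensor %arg0 ...
//     ...  // original uses of the tensor argument now use %t2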