Bufferize.cpp
//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/DebugLog.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

namespace {

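/// Parse the pass's analysisHeuristic string option into the corresponding
/// OneShotBufferizationOptions::AnalysisHeuristic value.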
static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysisheuristic option");
}

struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizePassBase<
          OneShotBufferizePass> {
  using Base::Base;

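  /// Configure One-Shot Bufferize from the pass options (unless a set of
  /// options was provided programmatically) and run it on the module.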
  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);

      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' is allowed in "
            << getArgument();
        return signalPassFailure();
      }

      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }

      if (useEncodingForMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          if (auto rtt = dyn_cast<RankedTensorType>(t))
            return rtt.getEncoding();
          return std::nullopt;
        };
      }

      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](TensorType tensorType,
                                       Attribute memorySpace,
                                       const BufferizationOptions &options) {
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
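        // For example, passing dialect-filter=tensor,linalg (illustrative
        // dialect names) restricts bufferization to ops from those dialects.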
        if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All other ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationState state;
    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(
              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace

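// Example invocation of the pass from mlir-opt (illustrative sketch; the exact
// option spellings are defined in the pass's tablegen definition):
//
//   mlir-opt --one-shot-bufferize="bufferize-function-boundaries" input.mlir
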
//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toBufferOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toBufferOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_buffer ops.
    if (isa<ToBufferOp>(op)) {
      toBufferOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_buffer ops.
  DenseSet<Operation *> &toBufferOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

LogicalResult
bufferization::bufferizeOp(Operation *op, const BufferizationOptions &options,
                           BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState analysisState(options);
    if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
      return failure();
  }

286 
287  // Keep track of to_buffer ops.
288  DenseSet<Operation *> toBufferOps;
289  op->walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });
290 
291  // Gather all bufferizable ops in top-to-bottom order.
292  //
293  // We should ideally know the exact memref type of all operands when
294  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
295  // Otherwise, we have to use a memref type with a fully dynamic layout map to
296  // avoid copies. We are currently missing patterns for layout maps to
297  // canonicalize away (or canonicalize to more precise layouts).
298  SmallVector<Operation *> worklist;
299  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
300  if (options.isOpAllowed(op) && hasTensorSemantics(op))
301  worklist.push_back(op);
302  });
303 
304  // Keep track of all erased ops.
305  DenseSet<Operation *> erasedOps;
306 
307  // Bufferize all ops.
308  BufferizationRewriter rewriter(op->getContext(), erasedOps, toBufferOps,
309  worklist, options, statistics);
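  // Note: The worklist may grow while iterating because the rewriter's
  // listener appends newly created ops with tensor semantics (see
  // BufferizationRewriter::notifyOperationInserted). Hence the index-based
  // loop instead of a range-based one.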
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has multiple "
              "blocks");

    // Bufferize the op.
    LDBG(3) << "//===-------------------------------------------===//\n"
            << "IR after bufferizing: " << nextOp->getName();
    rewriter.setInsertionPoint(nextOp);
    if (failed(
            bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
      LDBG(2) << "failed to bufferize\n"
              << "//===-------------------------------------------===//";
      return nextOp->emitError("failed to bufferize op");
    }
    LDBG(3) << *op << "\n//===-------------------------------------------===//";
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_buffer(to_tensor(x)) pairs.
  for (Operation *op : toBufferOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToBufferToTensorPair(
        rewriter, cast<ToBufferOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  /// Check the result of bufferization. Return an error if an op was not
  /// bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToBufferOps are allowed in the output.
    if (isa<ToTensorOp, ToBufferOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

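// Example (illustrative sketch, not part of the upstream file): invoking
// bufferizeOp programmatically on some operation `rootOp` (hypothetical
// variable) with default options.
//
//   BufferizationOptions options;
//   options.allowUnknownOps = true;
//   BufferizationState state;
//   if (failed(bufferizeOp(rootOp, options, state)))
//     ...; // handle the error
//
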
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BufferLikeType> bufferType =
        bufferization::getBufferType(bbArg, options, state);
    if (failed(bufferType))
      return failure();
    newTypes.push_back(*bufferType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = bufferization::ToTensorOp::create(
          rewriter, bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BufferLikeType> operandBufferType =
          bufferization::getBufferType(operand, options, state);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = bufferization::ToBufferOp::create(
          rewriter, operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(),
                                                   type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}