Bufferize.cpp
//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

namespace {

static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysisheuristic option");
}

struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizePassBase<
          OneShotBufferizePass> {
  using Base::Base;

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);

      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' are allowed in "
            << getArgument();
        return signalPassFailure();
      }

      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }

      if (useEncodingForMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          if (auto rtt = dyn_cast<RankedTensorType>(t))
            return rtt.getEncoding();
          return std::nullopt;
        };
      }

      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](TensorType tensorType,
                                       Attribute memorySpace,
                                       const BufferizationOptions &options) {
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All other ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationState state;
    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(
              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace
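
// Example (illustrative sketch, not part of the upstream file): driving the
// same One-Shot Bufferize flow from C++ instead of through the registered
// pass. It mirrors the tail of runOnOperation() above; the helper name
// `bufferizeMyModule` and the chosen option values are assumptions made for
// illustration.
static LogicalResult bufferizeMyModule(ModuleOp module) {
  OneShotBufferizationOptions opt;
  opt.allowUnknownOps = false;            // fail if an op cannot be bufferized
  opt.bufferizeFunctionBoundaries = true; // also rewrite function signatures
  opt.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);

  BufferizationState state;
  BufferizationStatistics statistics;
  // Function boundary bufferization uses the module-level entry point;
  // otherwise runOneShotBufferize is used, exactly as in the pass above.
  if (opt.bufferizeFunctionBoundaries)
    return runOneShotModuleBufferize(module, opt, state, &statistics);
  return runOneShotBufferize(module, opt, state, &statistics);
}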

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toBufferOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toBufferOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_buffer ops.
    if (isa<ToBufferOp>(op)) {
      toBufferOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_buffer ops.
  DenseSet<Operation *> &toBufferOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

LogicalResult
bufferization::bufferizeOp(Operation *op, const BufferizationOptions &options,
                           BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState analysisState(options);
    if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
      return failure();
  }

  // Keep track of to_buffer ops.
  DenseSet<Operation *> toBufferOps;
  op->walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toBufferOps,
                                 worklist, options, statistics);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has multiple "
              "blocks");

    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(
            bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_buffer(to_tensor(x)) pairs.
  for (Operation *op : toBufferOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToBufferToTensorPair(
        rewriter, cast<ToBufferOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  /// Check the result of bufferization. Return an error if an op was not
  /// bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToBufferOps are allowed in the output.
    if (isa<ToTensorOp, ToBufferOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}
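
// Example (illustrative sketch, not part of the upstream file): bufferizeOp
// can also be used directly, without running the One-Shot analysis, when
// conservative copies are acceptable. With copyBeforeWrite = true the function
// inserts the tensor copies itself (see the top of bufferizeOp above). The
// helper name `bufferizeWithoutAnalysis` is an assumption made for
// illustration.
static LogicalResult bufferizeWithoutAnalysis(Operation *op) {
  BufferizationOptions options;
  options.copyBeforeWrite = true; // copy before every memory write
  options.allowUnknownOps = true; // keep non-bufferizable ops as-is
  BufferizationState state;
  return bufferization::bufferizeOp(op, options, state);
}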

LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BufferLikeType> bufferType =
        bufferization::getBufferType(bbArg, options, state);
    if (failed(bufferType))
      return failure();
    newTypes.push_back(*bufferType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
          bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BufferLikeType> operandBufferType =
          bufferization::getBufferType(operand, options, state);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToBufferOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}
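
// Example (illustrative sketch, not part of the upstream file): clients that
// assemble BufferizationOptions by hand can install the same kind of
// unknownTypeConverterFn that the pass configures above. Identity layouts give
// plain memref types; fully dynamic layouts accept any strided buffer without
// an extra cast or copy. The helper name and its boolean parameter are
// assumptions made for illustration.
static void configureUnknownTypeConversion(BufferizationOptions &options,
                                           bool useIdentityLayout) {
  options.unknownTypeConverterFn = [useIdentityLayout](
                                       TensorType tensorType,
                                       Attribute memorySpace,
                                       const BufferizationOptions &opts) {
    if (useIdentityLayout)
      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
                                                                  memorySpace);
    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                              memorySpace);
  };
}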