//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

namespace {

static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysis-heuristic option");
}
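
// Illustrative note (not part of the original file): the strings parsed above
// mirror the pass's `analysis-heuristic` option, so the traversal order used
// by the analysis can be selected on the command line, e.g.:
//
//   mlir-opt --one-shot-bufferize="analysis-heuristic=bottom-up-from-terminators" in.mlir
//
// with "bottom-up" being the default.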

struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizePassBase<
          OneShotBufferizePass> {
  using Base::Base;

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);

      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' are allowed in "
            << getArgument();
        return signalPassFailure();
      }

      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }

      if (useEncodingForMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          if (auto rtt = dyn_cast<RankedTensorType>(t))
            return rtt.getEncoding();
          return std::nullopt;
        };
      }
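      // Illustrative example (assumed IR, not part of the original file): with
      // `use-encoding-for-memory-space`, a ranked tensor whose encoding is,
      // say, the attribute `1 : i64` (e.g. tensor<8xf32, 1 : i64>) is
      // bufferized into a memref placed in memory space 1, whereas
      // `must-infer-memory-space` assigns no default memory space at all.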

      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = cast<TensorType>(value.getType());
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All other ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }
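    // Note (illustrative, not part of the original file): the filter set up
    // above corresponds to the `dialect-filter` pass option, e.g.
    //   --one-shot-bufferize="dialect-filter=tensor,linalg"
    // which restricts bufferization to ops of the listed dialects and leaves
    // all other ops untouched.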

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationState state;
    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(
              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace

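// Minimal usage sketch (illustrative only, not part of the original file):
// One-Shot Bufferize can also be driven programmatically, mirroring what
// runOnOperation() does above:
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   BufferizationState state;
//   BufferizationStatistics statistics;
//   if (failed(runOneShotModuleBufferize(moduleOp, options, state,
//                                        &statistics)))
//     return failure();
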
//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toBufferOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toBufferOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_buffer ops.
    if (isa<ToBufferOp>(op)) {
      toBufferOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_buffer ops.
  DenseSet<Operation *> &toBufferOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationState &bufferizationState,
                                         BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState analysisState(options);
    if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
      return failure();
  }

  // Keep track of to_buffer ops.
  DenseSet<Operation *> toBufferOps;
  op->walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toBufferOps,
                                 worklist, options, statistics);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has multiple "
              "blocks");

    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(
            bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_buffer(to_tensor(x)) pairs.
  for (Operation *op : toBufferOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToBufferToTensorPair(
        rewriter, cast<ToBufferOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToBufferOps are allowed in the output.
    if (isa<ToTensorOp, ToBufferOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

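// Illustrative example (assumed IR, not part of the original file): when an op
// that cannot be bufferized remains in the program (`allowUnknownOps`), it is
// bridged with materializations. A bufferized producer yields a buffer that is
// wrapped in a to_tensor op, and a bufferized consumer inserts a to_buffer op
// on its tensor operand, conceptually:
//
//   %t = bufferization.to_tensor %buffer ...
//   %u = "unknown.op"(%t) : (tensor<?xf32>) -> tensor<?xf32>
//   %m = bufferization.to_buffer %u ...
//
// Redundant to_buffer(to_tensor(x)) pairs created along the way are folded by
// the loop over `toBufferOps` in bufferizeOp above.
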
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(bbArg, options, state);
    if (failed(memrefType))
      return failure();
    newTypes.push_back(*memrefType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
          bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BaseMemRefType> operandBufferType =
          bufferization::getBufferType(operand, options, state);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToBufferOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}
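
// Sketch of the effect of bufferizeBlockSignature (assumed IR, not part of the
// original file): a tensor block argument is retyped to its buffer type, uses
// inside the block are rerouted through a newly inserted to_tensor op, and the
// forwarded operands of predecessor branch ops are adapted with to_buffer
// (plus a memref.cast if the bufferized types differ):
//
//   ^bb1(%arg0: tensor<4xf32>):        ^bb1(%arg0: memref<4xf32, ...>):
//     "use"(%arg0) ...           ==>     %t = bufferization.to_tensor %arg0 ...
//                                        "use"(%t) ...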