MLIR  16.0.0git
TensorCopyInsertion.cpp
Go to the documentation of this file.
1 //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "PassDetail.h"
12 
19 
20 using namespace mlir;
21 using namespace mlir::bufferization;
22 
// Run One-Shot Bufferize analysis (module-level or plain, per `options`) on
// `op`, then insert tensor copies to resolve the detected conflicts.
// Returns failure if the analysis fails. In analysis-only mode
// (`options.testAnalysisOnly`) it returns success right after the analysis,
// without inserting copies.
// (Doxygen lines 23-24, which hold this function's signature, are missing
// from this listing.)
25  OneShotAnalysisState state(op, options);
26  // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
27  // analysis depending on whether function boundary bufferization is enabled or
28  // not.
29  if (options.bufferizeFunctionBoundaries) {
// NOTE(review): cast<ModuleOp> requires `op` to be a ModuleOp whenever
// function boundary bufferization is enabled; presumably callers guarantee
// this -- confirm.
30  if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
31  return failure();
32  } else {
33  if (failed(analyzeOp(op, state)))
34  return failure();
35  }
36 
// In analysis-only mode, stop before any copies are inserted.
37  if (options.testAnalysisOnly)
38  return success();
39 
// Resolve conflicts using the analysis results gathered in `state`
// (presumably dispatching to the `const AnalysisState &` overload).
40  return insertTensorCopies(op, state);
41 }
42 
45  const AnalysisState &state) {
// Walk `op` and every op nested in it. For each bufferizable op:
//  1. if it has no escape attribute yet, attach one holding a bool per op
//     result (only tensor results that bufferize to an allocation can be
//     true);
//  2. ask the op to resolve its inplacability conflicts, creating any new IR
//     through `rewriter`.
// Returns failure if some op fails to resolve its conflicts; this interrupts
// the walk.
// (Doxygen lines 43-44, which begin this function's signature, are missing
// from this listing.)
46  IRRewriter rewriter(op->getContext());
47  StringRef escapeAttrName = BufferizationDialect::kEscapeAttrName;
48 
// NOTE(review): the callback parameter `op` shadows the outer `op`; inside
// the lambda it always names the op currently being visited.
49  WalkResult result = op->walk([&](Operation *op) {
50  auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
// Ops that are not bufferizable (or not allowed by the options) are left
// untouched. NOTE(review): walk() uses its default order here (presumably
// post-order), in which skip() should behave like advance() -- confirm.
51  if (!bufferizableOp)
52  return WalkResult::skip();
53 
54  // Find allocations without an `escape` attribute and add the attribute
55  // based on analysis results.
56  if (!op->hasAttr(escapeAttrName)) {
// One entry per op result, in result order.
57  SmallVector<bool> escapeAttrValue;
// Set once some result is a tensor that bufferizes to an allocation; the
// attribute is only attached in that case.
58  bool foundTensorResult = false;
59  for (OpResult opResult : op->getOpResults()) {
60  if (!opResult.getType().isa<TensorType>() ||
61  !bufferizableOp.bufferizesToAllocation(opResult)) {
62  escapeAttrValue.push_back(false);
63  continue;
64  }
65  foundTensorResult = true;
// The allocation escapes if no deallocs will be created, or if the tensor
// (or an aliasing tensor) is yielded from its containing block.
66  bool escape = !state.getOptions().createDeallocs ||
67  state.isTensorYielded(opResult);
68  escapeAttrValue.push_back(escape);
69  }
70  if (foundTensorResult)
71  op->setAttr(escapeAttrName, rewriter.getBoolArrayAttr(escapeAttrValue));
72  }
73 
74  // Find inplacability conflicts and resolve them. (Typically with explicit
75  // tensor copies in the form of AllocTensorOps.)
// New IR is created at the visited op (presumably right before it).
76  rewriter.setInsertionPoint(op);
77  if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
78  return WalkResult::interrupt();
79 
80  return WalkResult::advance();
81  });
82 
// Failure only if some op's conflict resolution interrupted the walk.
83  return failure(result.wasInterrupted());
84 }
85 
86 namespace {
// Pass that runs One-Shot Bufferize analysis and resolves the detected
// conflicts by inserting tensor copies (via insertTensorCopies).
// If options were passed to the constructor, the pass uses them. Otherwise it
// builds options from its own option fields (presumably declared by the
// generated TensorCopyInsertionBase).
87 struct TensorCopyInsertionPass
88  : TensorCopyInsertionBase<TensorCopyInsertionPass> {
89  TensorCopyInsertionPass()
90  : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
// NOTE: doxygen line 91 (the rest of this constructor's initializer list) is
// missing from this listing.
// NOTE(review): this single-argument constructor is not `explicit`, so it
// also acts as an implicit conversion from OneShotBufferizationOptions;
// consider marking it `explicit`.
92  TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
93  : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
94 
// Registers the bufferization dialect as a dependency of this pass,
// presumably because the copies it inserts (AllocTensorOps) belong to it.
95  void getDependentDialects(DialectRegistry &registry) const override {
96  registry.insert<bufferization::BufferizationDialect>();
97  }
98 
99  void runOnOperation() override {
// Options given at construction take precedence.
100  if (options) {
101  if (failed(insertTensorCopies(getOperation(), *options)))
102  signalPassFailure();
103  } else {
// Otherwise build options from the pass option fields. NOTE: doxygen line
// 104, presumably declaring the local `options` that shadows the member, is
// missing from this listing.
105  options.allowReturnAllocs = allowReturnAllocs;
106  options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
107  options.createDeallocs = createDeallocs;
// Clearing the default memory space removes the fallback that is otherwise
// used when a memory space cannot be inferred from the context.
108  if (mustInferMemorySpace)
109  options.defaultMemorySpace = None;
110  if (failed(insertTensorCopies(getOperation(), options)))
111  signalPassFailure();
112  }
113  }
114 
115 private:
// NOTE: doxygen line 116, presumably declaring the optional `options` member
// tested in runOnOperation, is missing from this listing.
117 };
118 } // namespace
119 
// Creates the pass without explicit options, so runOnOperation builds its
// options from the pass option fields.
// (Doxygen line 120, this function's signature, is missing from this listing.)
121  return std::make_unique<TensorCopyInsertionPass>();
122 }
123 
// Creates the pass with explicit options, which runOnOperation uses as-is.
// (Doxygen lines 124-125, this function's signature, are missing from this
// listing.)
126  return std::make_unique<TensorCopyInsertionPass>(options);
127 }
Include the generated interface declarations.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state)
Analyze op and its nested ops.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This is a value defined by a result of an operation.
Definition: Value.h:425
bool allowReturnAllocs
Specifies whether returning newly allocated memrefs should be allowed.
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
const BufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the ...
std::unique_ptr< Pass > createTensorCopyInsertionPass()
Create a pass that resolves out-of-place tensor OpOperands with copies.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::enable_if< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT >::type walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one)...
Definition: Operation.h:574
AnalysisState provides a variety of helper functions for dealing with tensor values.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:385
LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state)
Analyze moduleOp and its nested ops.
bool createDeallocs
Specifies whether dealloc ops should be generated along with alloc ops.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const
Try to cast the given op to BufferizableOpInterface if the op is allowlisted.
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not...
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:76
result_range getOpResults()
Definition: Operation.h:337
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:584
static llvm::ManagedStatic< PassManagerOptions > options
static WalkResult skip()
Definition: Visitors.h:52
virtual bool isTensorYielded(Value tensor) const
Return true if the given tensor (or an aliasing tensor) is yielded from the containing block...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:395
Optional< unsigned > defaultMemorySpace
The default memory space that should be used when it cannot be inferred from the context.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
State for analysis-enabled bufferization.
Options for analysis-enabled bufferization.