MLIR  17.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 
20 namespace {
21 /// Bufferization of arith.constant. Replace with memref.get_global.
22 struct ConstantOpInterface
23  : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
24  arith::ConstantOp> {
25  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
26  const BufferizationOptions &options) const {
27  auto constantOp = cast<arith::ConstantOp>(op);
28 
29  // TODO: Implement memory space for this op. E.g., by adding a memory_space
30  // attribute to ConstantOp.
31  if (options.defaultMemorySpace != Attribute())
32  return op->emitError("memory space not implemented yet");
33 
34  // Only ranked tensors are supported.
35  if (!constantOp.getType().isa<RankedTensorType>())
36  return failure();
37 
38  // Only constants inside a module are supported.
39  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
40  if (!moduleOp)
41  return failure();
42 
43  // Create global memory segment and replace tensor with memref pointing to
44  // that memory segment.
46  getGlobalFor(constantOp, options.bufferAlignment);
47  if (failed(globalOp))
48  return failure();
49  memref::GlobalOp globalMemref = *globalOp;
50  replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
51  rewriter, op, globalMemref.getType(), globalMemref.getName());
52 
53  return success();
54  }
55 
56  bool isWritable(Operation *op, Value value,
57  const AnalysisState &state) const {
58  // Memory locations returned by memref::GetGlobalOp may not be written to.
59  assert(value.isa<OpResult>());
60  return false;
61  }
62 };
63 
64 struct IndexCastOpInterface
65  : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
66  arith::IndexCastOp> {
67  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
68  const AnalysisState &state) const {
69  return false;
70  }
71 
72  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
73  const AnalysisState &state) const {
74  return false;
75  }
76 
77  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
78  const AnalysisState &state) const {
79  return {op->getResult(0)};
80  }
81 
82  BufferRelation bufferRelation(Operation *op, OpResult opResult,
83  const AnalysisState &state) const {
85  }
86 
87  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
88  const BufferizationOptions &options) const {
89  auto castOp = cast<arith::IndexCastOp>(op);
90  auto resultTensorType = castOp.getType().cast<TensorType>();
91 
92  FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
93  if (failed(source))
94  return failure();
95  auto sourceType = source->getType().cast<BaseMemRefType>();
96 
97  // Result type should have same layout and address space as the source type.
98  BaseMemRefType resultType;
99  if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
100  resultType = MemRefType::get(
101  rankedMemRefType.getShape(), resultTensorType.getElementType(),
102  rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
103  } else {
104  auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
105  resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
106  unrankedMemrefType.getMemorySpace());
107  }
108 
109  replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
110  *source);
111  return success();
112  }
113 };
114 
115 /// Bufferization of arith.select. Just replace the operands.
116 struct SelectOpInterface
117  : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
118  arith::SelectOp> {
119  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
120  const AnalysisState &state) const {
121  return false;
122  }
123 
124  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
125  const AnalysisState &state) const {
126  return false;
127  }
128 
129  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
130  const AnalysisState &state) const {
131  return {op->getOpResult(0) /*result*/};
132  }
133 
135  getAliasingOpOperand(Operation *op, OpResult opResult,
136  const AnalysisState &state) const {
137  return {&op->getOpOperand(1) /*true_value*/,
138  &op->getOpOperand(2) /*false_value*/};
139  }
140 
141  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
142  const BufferizationOptions &options) const {
143  auto selectOp = cast<arith::SelectOp>(op);
144  Location loc = selectOp.getLoc();
145 
146  // TODO: It would be more efficient to copy the result of the `select` op
147  // instead of its OpOperands. In the worst case, 2 copies are inserted at
148  // the moment (one for each tensor). When copying the op result, only one
149  // copy would be needed.
150  FailureOr<Value> maybeTrueBuffer =
151  getBuffer(rewriter, selectOp.getTrueValue(), options);
152  FailureOr<Value> maybeFalseBuffer =
153  getBuffer(rewriter, selectOp.getFalseValue(), options);
154  if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
155  return failure();
156  Value trueBuffer = *maybeTrueBuffer;
157  Value falseBuffer = *maybeFalseBuffer;
158 
159  // The "true" and the "false" operands must have the same type. If the
160  // buffers have different types, they differ only in their layout map. Cast
161  // both of them to the most dynamic MemRef type.
162  if (trueBuffer.getType() != falseBuffer.getType()) {
163  auto targetType =
164  bufferization::getBufferType(selectOp.getResult(), options);
165  if (failed(targetType))
166  return failure();
167  trueBuffer =
168  rewriter.create<memref::CastOp>(loc, *targetType, trueBuffer);
169  falseBuffer =
170  rewriter.create<memref::CastOp>(loc, *targetType, falseBuffer);
171  }
172 
173  replaceOpWithNewBufferizedOp<arith::SelectOp>(
174  rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
175  return success();
176  }
177 
180  const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
181  auto selectOp = cast<arith::SelectOp>(op);
182  assert(value == selectOp.getResult() && "invalid value");
183  auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
184  options, fixedTypes);
185  auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
186  options, fixedTypes);
187  if (failed(trueType) || failed(falseType))
188  return failure();
189  if (*trueType == *falseType)
190  return *trueType;
191  if (trueType->getMemorySpace() != falseType->getMemorySpace())
192  return op->emitError("inconsistent memory space on true/false operands");
193 
194  // If the buffers have different types, they differ only in their layout
195  // map.
196  auto memrefType = trueType->cast<MemRefType>();
198  RankedTensorType::get(memrefType.getShape(),
199  memrefType.getElementType()),
200  memrefType.getMemorySpace());
201  }
202 
203  BufferRelation bufferRelation(Operation *op, OpResult opResult,
204  const AnalysisState &state) const {
205  return BufferRelation::None;
206  }
207 };
208 
209 } // namespace
210 
212  DialectRegistry &registry) {
213  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
214  ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
215  IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
216  SelectOp::attachInterface<SelectOpInterface>(*ctx);
217  });
218 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:113
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This class represents an operand of an operation.
Definition: Value.h:255
This is a value defined by a result of an operation.
Definition: Value.h:450
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
OpResult getOpResult(unsigned idx)
Definition: Operation.h:382
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:349
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:368
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:224
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:77
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
bool isa() const
Definition: Value.h:98
Type getType() const
Return the type of this value.
Definition: Value.h:122
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
FailureOr< memref::GlobalOp > getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment)
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specify fine-grain relationship between buffers to enable more analysis.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.