// MLIR 22.0.0git — Doxygen page header (extraction artifact, kept as comment)
// BufferizableOpInterfaceImpl.cpp
// "Go to the documentation of this file."
1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;
21
22namespace {
23
24/// Generic conversion for any DestinationStyleOpInterface on tensors.
25static LogicalResult bufferizeDestinationStyleOpInterface(
26 RewriterBase &rewriter, DestinationStyleOpInterface op,
27 const BufferizationOptions &options, const BufferizationState &state) {
28 // Take a guard before anything else.
29 OpBuilder::InsertionGuard g(rewriter);
30 rewriter.setInsertionPoint(op);
31
32 // Nothing to do. This op is already bufferized.
33 if (op.hasPureBufferSemantics())
34 return success();
35
36 // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
37 // basis.
38 if (!op.hasPureTensorSemantics())
39 return op->emitError() << "op does not have pure tensor semantics";
40
41 // New input operands for the cloned op.
42 SmallVector<Value> newInputBuffers;
43 newInputBuffers.reserve(op.getNumDpsInputs());
44 for (OpOperand *opOperand : op.getDpsInputOperands()) {
45 if (op.isScalar(opOperand)) {
46 newInputBuffers.push_back(opOperand->get());
47 continue;
48 }
49 FailureOr<Value> buffer =
50 getBuffer(rewriter, opOperand->get(), options, state);
51 if (failed(buffer))
52 return failure();
53 newInputBuffers.push_back(*buffer);
54 }
55
56 // New output operands for the cloned op.
57 SmallVector<Value> newOutputBuffers;
58 for (OpResult opResult : op->getOpResults()) {
59 OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
60 FailureOr<Value> resultBuffer =
61 getBuffer(rewriter, opOperand->get(), options, state);
62 if (failed(resultBuffer))
63 return failure();
64 newOutputBuffers.push_back(*resultBuffer);
65 }
66
67 // Merge input/output operands.
68 SmallVector<Value> newOperands = newInputBuffers;
69 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
70
71 // Set insertion point now that potential alloc/dealloc are introduced.
72 rewriter.setInsertionPoint(op);
73 // Clone the op, but use the new operands. Move the existing block into the
74 // new op. Since the new op does not have any tensor results, it does not
75 // return anything.
76 assert(op->getNumRegions() == 1 && "expected that op has 1 region");
77 OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{},
78 op->getAttrs());
79 opState.addRegion();
80 Operation *newOp = Operation::create(opState);
81 newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
82 op->getRegion(0).getBlocks());
83
84 // We don't want the rewriter tracks an incomplete operation, so insert new
85 // operation after op was fully constructed.
86 rewriter.insert(newOp);
87
88 // Replace the results of the old op with the new output buffers.
89 replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
90
91 return success();
92}
93
94/// Bufferization of linalg.generic. Replace with a new linalg.generic that
95/// operates entirely on memrefs.
96template <typename OpTy>
97struct LinalgOpInterface
98 : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
99 OpTy> {
100 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
101 const AnalysisState &state) const {
102 // Operand is read if it is used in the computation.
103 auto linalgOp = cast<linalg::LinalgOp>(op);
104 return linalgOp.payloadUsesValueFromOperand(&opOperand);
105 }
106
107 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
108 const AnalysisState &state) const {
109 // Operand is written to if it is not an input/init.
110 auto dpsOp = cast<DestinationStyleOpInterface>(op);
111 return dpsOp.isDpsInit(&opOperand);
112 }
113
114 bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
115 ArrayRef<OpOperand *> opOperands) const {
116 auto linalgOp = cast<linalg::LinalgOp>(op);
117
118 // Accesses into sparse data structures are not necessarily elementwise.
120 return false;
121
122 // All loops must be parallel.
123 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
124 return false;
125
126 // All index maps of tensors must be identity maps.
127 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
128 assert(linalgOp->getNumOperands() == indexingMaps.size() &&
129 "unexpected number of indexing maps");
130 for (auto [operand, map] :
131 llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
132 // Non-tensors do not participate in bufferization, so they can be
133 // ignored.
134 if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
135 continue;
136 // Only consider operands in `opOperands`.
137 if (!llvm::is_contained(opOperands, &operand))
138 continue;
139 // TODO: This could be generalized to other indexing maps. (All indexing
140 // must be the same.)
141 if (!map.isIdentity())
142 return false;
143 }
144
145 return true;
146 }
147
148 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
149 const BufferizationOptions &options,
150 BufferizationState &state) const {
151 return bufferizeDestinationStyleOpInterface(
152 rewriter, cast<DestinationStyleOpInterface>(op), options, state);
153 }
154};
155
156/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
157/// the `BufferizableOpInterface` with each of them.
158template <typename... Ops>
159struct LinalgOpInterfaceHelper {
160 static void registerOpInterface(MLIRContext *ctx) {
161 (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
162 }
163};
164
165struct SoftmaxOpInterface
166 : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
167 linalg::SoftmaxOp> {
168 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
169 const AnalysisState &state) const {
170 // Output operand is not read.
171 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
172 return &opOperand == &softmaxOp.getInputMutable();
173 }
174
175 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
176 const BufferizationOptions &options,
177 BufferizationState &state) const {
178 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
179 FailureOr<Value> inputBuffer =
180 getBuffer(rewriter, softmaxOp.getInput(), options, state);
181 if (failed(inputBuffer))
182 return failure();
183 FailureOr<Value> outputBuffer =
184 getBuffer(rewriter, softmaxOp.getOutput(), options, state);
185 if (failed(outputBuffer))
186 return failure();
187 linalg::SoftmaxOp::create(rewriter, softmaxOp.getLoc(),
188 /*result=*/TypeRange(), *inputBuffer,
189 *outputBuffer, softmaxOp.getDimension());
190 replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
191 return success();
192 }
193};
194} // namespace
195
197 DialectRegistry &registry) {
198 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
199 // Register all Linalg structured ops. `LinalgOp` is an interface and it is
200 // not possible to attach an external interface to an existing interface.
201 // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
202 LinalgOpInterfaceHelper<
203#define GET_OP_LIST
204#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
205
206 >::registerOpInterface(ctx);
207
208 SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
209 });
210}
return success()
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
iterator begin()
Definition Region.h:55
BlockListType & getBlocks()
Definition Region.h:45
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...