/// Bufferizes a destination-style op whose operands are all tensors.
/// Non-scalar inputs and every init (output) operand are replaced by the
/// buffers returned from getBuffer. The op is then rebuilt on those buffers,
/// and the old op's results are replaced by the new output buffers.
///
/// NOTE(review): this listing appears to be scraped from a doxygen source
/// view. The original file's line numbers are fused onto the start of each
/// line (25, 33, 38, ...), and the gaps in that numbering mean several lines
/// are missing. The missing lines include the parameter list after line 25,
/// the bodies of the early exits, whatever follows each getBuffer call
/// (presumably a failure check), and the construction of the replacement op
/// around line 82. Restore from the upstream source before compiling. The
/// comments below describe only the visible lines.
25static LogicalResult bufferizeDestinationStyleOpInterface(
// Ops that already operate purely on buffers need no rewriting (the body of
// this early exit is not visible here).
33 if (op.hasPureBufferSemantics())
// Ops that mix tensor and buffer operands are rejected with an error.
38 if (!op.hasPureTensorSemantics())
39 return op->emitError() <<
"op does not have pure tensor semantics";
// New input operands: scalar inputs are forwarded unchanged; every other
// input is swapped for the buffer returned by getBuffer.
43 newInputBuffers.reserve(op.getNumDpsInputs());
44 for (
OpOperand *opOperand : op.getDpsInputOperands()) {
45 if (op.isScalar(opOperand)) {
46 newInputBuffers.push_back(opOperand->get());
49 FailureOr<Value> buffer =
50 getBuffer(rewriter, opOperand->get(),
options, state);
53 newInputBuffers.push_back(*buffer);
// New output operands: for each op result, bufferize the init operand tied
// to it (same result number). These buffers also replace the op's results
// at the end.
58 for (
OpResult opResult : op->getOpResults()) {
59 OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
60 FailureOr<Value> resultBuffer =
61 getBuffer(rewriter, opOperand->
get(),
options, state);
64 newOutputBuffers.push_back(*resultBuffer);
// Output buffers are appended after the inputs. The line that seeds
// `newOperands` is not visible; presumably it starts from newInputBuffers.
69 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
// The rebuilt op carries over exactly one region. Its blocks are handed to
// the new op via getBlocks(); the surrounding splice/create lines are not
// visible here.
76 assert(op->getNumRegions() == 1 &&
"expected that op has 1 region");
82 op->getRegion(0).getBlocks());
// Replace the original op, using the output buffers as its result values.
89 replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
/// External bufferization model attached to each structured Linalg op type
/// `OpTy` (LinalgOpInterfaceHelper attaches LinalgOpInterface<Ops> per op).
///
/// NOTE(review): fragmentary doxygen scrape. Original line numbers are fused
/// onto each line, and several lines are missing: the base-class list after
/// line 97, the return statements inside bufferizesToElementwiseAccess, and
/// the closing braces. Restore from upstream before compiling.
96template <
typename OpTy>
97struct LinalgOpInterface
// An operand counts as read exactly when the op's payload uses its value.
100 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
101 const AnalysisState &state)
const {
103 auto linalgOp = cast<linalg::LinalgOp>(op);
104 return linalgOp.payloadUsesValueFromOperand(&opOperand);
// Only destination (init) operands are written.
107 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
108 const AnalysisState &state)
const {
110 auto dpsOp = cast<DestinationStyleOpInterface>(op);
111 return dpsOp.isDpsInit(&opOperand);
// Decides whether the given operands are accessed elementwise. Visible
// checks:
// - every loop must be parallel (line 123);
// - one indexing map per operand (asserted);
// - for each ranked-tensor/memref operand in `opOperands`, the indexing map
//   must be the identity (line 141).
// Operands of other types (line 134) or not in `opOperands` (line 137) are
// filtered out. The statements executed by each check are not visible here;
// presumably `return false` / `continue` (confirm against upstream).
114 bool bufferizesToElementwiseAccess(Operation *op,
const AnalysisState &state,
115 ArrayRef<OpOperand *> opOperands)
const {
116 auto linalgOp = cast<linalg::LinalgOp>(op);
123 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
127 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
128 assert(linalgOp->getNumOperands() == indexingMaps.size() &&
129 "unexpected number of indexing maps");
130 for (
auto [operand, map] :
131 llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
134 if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
137 if (!llvm::is_contained(opOperands, &operand))
141 if (!map.isIdentity())
// Bufferization is delegated to the generic destination-style routine.
148 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
149 const BufferizationOptions &
options,
150 BufferizationState &state)
const {
151 return bufferizeDestinationStyleOpInterface(
152 rewriter, cast<DestinationStyleOpInterface>(op),
options, state);
/// Attaches LinalgOpInterface<Op> to every op type in the `Ops` pack. It uses
/// a C++17 fold expression over the comma operator, so each op gets its own
/// model instantiation.
///
/// NOTE(review): fragmentary doxygen scrape. Line numbers are fused onto each
/// line, and the closing braces are not visible.
158template <
typename... Ops>
159struct LinalgOpInterfaceHelper {
160 static void registerOpInterface(MLIRContext *ctx) {
161 (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
/// External bufferization model for linalg::SoftmaxOp.
///
/// NOTE(review): fragmentary doxygen scrape. Line numbers are fused onto each
/// line. Several lines are missing: the base-class list after line 165, the
/// lines after each getBuffer call (presumably failure checks), line 188
/// (presumably remaining create arguments), and the closing braces.
165struct SoftmaxOpInterface
// Only the `input` operand is read; the output operand is not.
168 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
169 const AnalysisState &state)
const {
171 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
172 return &opOperand == &softmaxOp.getInputMutable();
// Obtain buffers for the input and output and create a new linalg.softmax
// on them with the same dimension. The original op is then replaced, with
// the output buffer as its result value.
175 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
176 const BufferizationOptions &
options,
177 BufferizationState &state)
const {
178 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
179 FailureOr<Value> inputBuffer =
180 getBuffer(rewriter, softmaxOp.getInput(),
options, state);
183 FailureOr<Value> outputBuffer =
184 getBuffer(rewriter, softmaxOp.getOutput(),
options, state);
187 linalg::SoftmaxOp::create(rewriter, softmaxOp.getLoc(),
189 *outputBuffer, softmaxOp.getDimension());
190 replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
// Attach LinalgOpInterface to every op listed in the generated
// LinalgStructuredOps.cpp.inc. The list is presumably selected by a
// GET_OP_LIST-style define on the missing line 203; confirm against upstream.
// NOTE(review): fragmentary doxygen scrape. Line numbers are fused onto each
// line, and the enclosing registration function (which provides `ctx`) is
// not visible.
202 LinalgOpInterfaceHelper<
204#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
206 >::registerOpInterface(ctx);
// Softmax is not covered by the generated list above and gets its own model.
208 SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
BlockListType & getBlocks()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class provides an abstraction over the various different ranges of value types.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
bool hasAnySparseOperand(Operation *op)
Returns true iff the MLIR operation has any sparse operand.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.