21 using namespace linalg;
28 bufferizeDestinationStyleOpInterface(
RewriterBase &rewriter,
29 DestinationStyleOpInterface op,
36 if (op.hasPureBufferSemantics())
41 if (!op.hasPureTensorSemantics())
42 return op->emitError() <<
"op does not have pure tensor semantics";
46 newInputBuffers.reserve(op.getNumDpsInputs());
47 for (
OpOperand *opOperand : op.getDpsInputOperands()) {
48 if (op.isScalar(opOperand)) {
49 newInputBuffers.push_back(opOperand->get());
55 newInputBuffers.push_back(*buffer);
60 for (
OpResult opResult : op->getOpResults()) {
61 OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
62 FailureOr<Value> resultBuffer =
64 if (failed(resultBuffer))
66 newOutputBuffers.push_back(*resultBuffer);
71 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
78 assert(op->getNumRegions() == 1 &&
"expected that op has 1 region");
84 op->getRegion(0).getBlocks());
98 template <
typename OpTy>
99 struct LinalgOpInterface
105 auto linalgOp = cast<linalg::LinalgOp>(op);
106 return linalgOp.payloadUsesValueFromOperand(&opOperand);
112 auto dpsOp = cast<DestinationStyleOpInterface>(op);
113 return dpsOp.isDpsInit(&opOperand);
118 auto linalgOp = cast<linalg::LinalgOp>(op);
125 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
130 assert(linalgOp->getNumOperands() == indexingMaps.size() &&
131 "unexpected number of indexing maps");
132 for (
auto [operand, map] :
133 llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
136 if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
139 if (!llvm::is_contained(opOperands, &operand))
143 if (!map.isIdentity())
152 return bufferizeDestinationStyleOpInterface(
153 rewriter, cast<DestinationStyleOpInterface>(op),
options);
159 template <
typename... Ops>
160 struct LinalgOpInterfaceHelper {
161 static void registerOpInterface(
MLIRContext *ctx) {
162 (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
166 struct SoftmaxOpInterface
172 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
173 return &opOperand == &softmaxOp.getInputMutable();
178 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
179 FailureOr<Value> inputBuffer =
181 if (failed(inputBuffer))
183 FailureOr<Value> outputBuffer =
185 if (failed(outputBuffer))
187 rewriter.
create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
189 *outputBuffer, softmaxOp.getDimension());
202 LinalgOpInterfaceHelper<
204 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
205 >::registerOpInterface(ctx);
207 SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
BlockListType & getBlocks()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class provides an abstraction over the various different ranges of value types.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.