21 using namespace linalg;
27 static LogicalResult bufferizeDestinationStyleOpInterface(
35 if (op.hasPureBufferSemantics())
40 if (!op.hasPureTensorSemantics())
41 return op->emitError() <<
"op does not have pure tensor semantics";
45 newInputBuffers.reserve(op.getNumDpsInputs());
46 for (
OpOperand *opOperand : op.getDpsInputOperands()) {
47 if (op.isScalar(opOperand)) {
48 newInputBuffers.push_back(opOperand->get());
51 FailureOr<Value> buffer =
55 newInputBuffers.push_back(*buffer);
60 for (
OpResult opResult : op->getOpResults()) {
61 OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
62 FailureOr<Value> resultBuffer =
64 if (failed(resultBuffer))
66 newOutputBuffers.push_back(*resultBuffer);
71 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
78 assert(op->getNumRegions() == 1 &&
"expected that op has 1 region");
84 op->getRegion(0).getBlocks());
98 template <
typename OpTy>
99 struct LinalgOpInterface
105 auto linalgOp = cast<linalg::LinalgOp>(op);
106 return linalgOp.payloadUsesValueFromOperand(&opOperand);
112 auto dpsOp = cast<DestinationStyleOpInterface>(op);
113 return dpsOp.isDpsInit(&opOperand);
118 auto linalgOp = cast<linalg::LinalgOp>(op);
125 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
130 assert(linalgOp->getNumOperands() == indexingMaps.size() &&
131 "unexpected number of indexing maps");
132 for (
auto [operand, map] :
133 llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
136 if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
139 if (!llvm::is_contained(opOperands, &operand))
143 if (!map.isIdentity())
153 return bufferizeDestinationStyleOpInterface(
154 rewriter, cast<DestinationStyleOpInterface>(op),
options, state);
160 template <
typename... Ops>
161 struct LinalgOpInterfaceHelper {
162 static void registerOpInterface(
MLIRContext *ctx) {
163 (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
167 struct SoftmaxOpInterface
173 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
174 return &opOperand == &softmaxOp.getInputMutable();
180 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
181 FailureOr<Value> inputBuffer =
183 if (failed(inputBuffer))
185 FailureOr<Value> outputBuffer =
187 if (failed(outputBuffer))
189 rewriter.
create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
191 *outputBuffer, softmaxOp.getDimension());
204 LinalgOpInterfaceHelper<
206 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
208 >::registerOpInterface(ctx);
210 SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
BlockListType & getBlocks()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
BufferizationState provides information about the state of the IR during the bufferization process.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...