19 using namespace linalg;
25 static LogicalResult bufferizeDestinationStyleOpInterface(
33 if (op.hasPureBufferSemantics())
38 if (!op.hasPureTensorSemantics())
39 return op->emitError() <<
"op does not have pure tensor semantics";
43 newInputBuffers.reserve(op.getNumDpsInputs());
44 for (
OpOperand *opOperand : op.getDpsInputOperands()) {
45 if (op.isScalar(opOperand)) {
46 newInputBuffers.push_back(opOperand->get());
49 FailureOr<Value> buffer =
53 newInputBuffers.push_back(*buffer);
58 for (
OpResult opResult : op->getOpResults()) {
59 OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
60 FailureOr<Value> resultBuffer =
62 if (failed(resultBuffer))
64 newOutputBuffers.push_back(*resultBuffer);
69 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
76 assert(op->getNumRegions() == 1 &&
"expected that op has 1 region");
82 op->getRegion(0).getBlocks());
96 template <
typename OpTy>
97 struct LinalgOpInterface
103 auto linalgOp = cast<linalg::LinalgOp>(op);
104 return linalgOp.payloadUsesValueFromOperand(&opOperand);
110 auto dpsOp = cast<DestinationStyleOpInterface>(op);
111 return dpsOp.isDpsInit(&opOperand);
116 auto linalgOp = cast<linalg::LinalgOp>(op);
123 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
128 assert(linalgOp->getNumOperands() == indexingMaps.size() &&
129 "unexpected number of indexing maps");
130 for (
auto [operand, map] :
131 llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
134 if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
137 if (!llvm::is_contained(opOperands, &operand))
141 if (!map.isIdentity())
151 return bufferizeDestinationStyleOpInterface(
152 rewriter, cast<DestinationStyleOpInterface>(op),
options, state);
158 template <
typename... Ops>
159 struct LinalgOpInterfaceHelper {
160 static void registerOpInterface(
MLIRContext *ctx) {
161 (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
165 struct SoftmaxOpInterface
171 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
172 return &opOperand == &softmaxOp.getInputMutable();
178 auto softmaxOp = cast<linalg::SoftmaxOp>(op);
179 FailureOr<Value> inputBuffer =
181 if (failed(inputBuffer))
183 FailureOr<Value> outputBuffer =
185 if (failed(outputBuffer))
187 linalg::SoftmaxOp::create(rewriter, softmaxOp.getLoc(),
189 *outputBuffer, softmaxOp.getDimension());
202 LinalgOpInterfaceHelper<
204 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
206 >::registerOpInterface(ctx);
208 SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
BlockListType & getBlocks()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
BufferizationState provides information about the state of the IR during the bufferization process.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...