21 using namespace linalg;
28 bufferizeDestinationStyleOpInterface(
RewriterBase &rewriter,
29 DestinationStyleOpInterface op,
36 if (op.hasPureBufferSemantics())
41 if (!op.hasPureTensorSemantics())
42 return op->
emitError() <<
"op does not have pure tensor semantics";
46 newInputBuffers.reserve(op.getNumDpsInputs());
47 for (
OpOperand *opOperand : op.getDpsInputOperands()) {
48 if (op.isScalar(opOperand)) {
49 newInputBuffers.push_back(opOperand->get());
55 newInputBuffers.push_back(*buffer);
61 OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
66 newOutputBuffers.push_back(*resultBuffer);
71 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
78 assert(op->
getNumRegions() == 1 &&
"expected that op has 1 region");
82 newOp->getRegion(0).
begin());
92 template <
typename OpTy>
93 struct LinalgOpInterface
99 auto linalgOp = cast<linalg::LinalgOp>(op);
100 return linalgOp.payloadUsesValueFromOperand(&opOperand);
106 auto dpsOp = cast<DestinationStyleOpInterface>(op);
107 return dpsOp.isDpsInit(&opOperand);
112 auto linalgOp = cast<linalg::LinalgOp>(op);
119 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
124 assert(linalgOp->getNumOperands() == indexingMaps.size() &&
125 "unexpected number of indexing maps");
126 for (
auto [operand, map] :
127 llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
130 if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
133 if (!llvm::is_contained(opOperands, &operand))
137 if (!map.isIdentity())
146 return bufferizeDestinationStyleOpInterface(
147 rewriter, cast<DestinationStyleOpInterface>(op),
options);
153 template <
typename... Ops>
154 struct LinalgOpInterfaceHelper {
155 static void registerOpInterface(
MLIRContext *ctx) {
156 (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
167 LinalgOpInterfaceHelper<
169 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
170 >::registerOpInterface(ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getOpResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...