// External model attaching BufferizableOpInterface behavior to arith.constant.
// NOTE(review): this listing is an excerpt; the lines between the numbered
// fragments (method signatures, some returns, closing braces) are not visible
// here, so the comments below describe only the statements that are shown.
23 struct ConstantOpInterface
24 :
public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
// View the generic op as an arith::ConstantOp. `type` is null when the
// constant's type is not a RankedTensorType (dyn_cast, not cast).
29 auto constantOp = cast<arith::ConstantOp>(op);
30 auto type = dyn_cast<RankedTensorType>(constantOp.getType());
// Query the options' default-memory-space callback for this type and, when it
// yields a value, use that as the memory space.
37 if (
auto memSpace =
options.defaultMemorySpaceFn(type))
38 memorySpace = *memSpace;
// Error reported on the constant op when no memory space was obtained
// (presumably the `else` branch of the check above; that line is omitted).
40 return constantOp->emitError(
"could not infer memory space");
// Nearest enclosing ModuleOp of the constant.
43 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
// Obtain a memref::GlobalOp for the constant; the callee name and leading
// arguments are on an omitted line (presumably getGlobalFor — TODO confirm).
// The visible trailing arguments are the buffer alignment from the options and
// the memory space chosen above.
49 FailureOr<memref::GlobalOp> globalOp =
51 options.bufferAlignment, memorySpace);
// Replace the constant op with a memref::GetGlobalOp that refers to the global
// by name and has the global's type.
54 memref::GlobalOp globalMemref = *globalOp;
55 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
56 rewriter, op, globalMemref.getType(), globalMemref.getName());
// Separate fragment (listing line 64): asserts that `value` is an OpResult.
// The enclosing method is not visible in this excerpt.
64 assert(isa<OpResult>(value));
// External model attaching BufferizableOpInterface behavior to
// arith.index_cast.
// NOTE(review): this listing is an excerpt; several lines (method signatures,
// the beginnings of some calls, closing braces) are not visible here.
69 struct IndexCastOpInterface
70 :
public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
// View the generic op as an arith::IndexCastOp; its result type is cast to
// TensorType (cast<> asserts on mismatch).
90 auto castOp = cast<arith::IndexCastOp>(op);
91 auto resultTensorType = cast<TensorType>(castOp.getType());
// `source` holds a Value or failure; the producing call is on omitted lines
// (presumably a getBuffer lookup for the cast's operand — TODO confirm).
93 FailureOr<Value> source =
// Type of the source value, cast to the common memref base type.
97 auto sourceType = cast<BaseMemRefType>(source->getType());
// Ranked case: the visible arguments reuse the source's shape, layout and
// memory space but take the element type from the result tensor type. The
// call these arguments belong to starts on an omitted line.
101 if (
auto rankedMemRefType = dyn_cast<MemRefType>(sourceType)) {
103 rankedMemRefType.getShape(), resultTensorType.getElementType(),
104 rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
// Otherwise the source type is cast to UnrankedMemRefType and its memory space
// is passed along (the call's beginning is on an omitted line).
106 auto unrankedMemrefType = cast<UnrankedMemRefType>(sourceType);
108 unrankedMemrefType.getMemorySpace());
// Replace the op with a new arith::IndexCastOp using `resultType`; the
// remaining arguments are on an omitted line.
111 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
// External model attaching BufferizableOpInterface behavior to arith.select.
// NOTE(review): this listing is an excerpt; method signatures, some returns,
// assignment targets and closing braces are on lines that are not visible.
118 struct SelectOpInterface
119 :
public BufferizableOpInterface::ExternalModel<SelectOpInterface,
// View the generic op as an arith::SelectOp.
140 auto selectOp = cast<arith::SelectOp>(op);
// Reject any select whose condition type is not a 1-bit integer (i1).
147 if (!selectOp.getCondition().getType().isInteger(1))
148 return op->
emitOpError(
"only i1 condition values are supported");
// Buffers for the two operands; the producing calls are on omitted lines
// (presumably getBuffer on the true/false values — TODO confirm).
154 FailureOr<Value> maybeTrueBuffer =
156 FailureOr<Value> maybeFalseBuffer =
// Check both lookups for failure before dereferencing them; the statement
// guarded by this `if` is on an omitted line.
158 if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
160 Value trueBuffer = *maybeTrueBuffer;
161 Value falseBuffer = *maybeFalseBuffer;
// `targetType` is defined on omitted lines; it is the common type both
// buffers are compared against below.
169 if (failed(targetType))
// When a buffer's type differs from the target type, a memref::CastOp to the
// target type is created for it. The assignment target of each create call is
// on an omitted line (presumably the buffer variable itself — TODO confirm).
171 if (trueBuffer.
getType() != *targetType)
173 rewriter.
create<memref::CastOp>(loc, *targetType, trueBuffer);
174 if (falseBuffer.
getType() != *targetType)
176 rewriter.
create<memref::CastOp>(loc, *targetType, falseBuffer);
// Replace the op with a new arith::SelectOp over the same condition and the
// two buffers.
179 replaceOpWithNewBufferizedOp<arith::SelectOp>(
180 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
// Start of a second method returning FailureOr<BaseMemRefType>; its name and
// parameter list are on omitted lines (presumably getBufferType — TODO
// confirm).
184 FailureOr<BaseMemRefType>
188 auto selectOp = cast<arith::SelectOp>(op);
// Only the select's own result may be queried.
189 assert(value == selectOp.getResult() &&
"invalid value");
// Argument tails of two calls, one per operand; the callee and the variables
// receiving the results are on omitted lines (the results are presumably the
// `trueType` / `falseType` checked below).
191 selectOp.getTrueValue(),
options, state, invocationStack);
193 selectOp.getFalseValue(),
options, state, invocationStack);
// Checks on the two operand types: either lookup failed, or both types are
// identical. The statements these `if`s guard are on omitted lines.
194 if (failed(trueType) || failed(falseType))
196 if (*trueType == *falseType)
// Differing memory spaces on the two operands are reported as an error.
198 if (trueType->getMemorySpace() != falseType->getMemorySpace())
199 return op->
emitError(
"inconsistent memory space on true/false operands");
// The true operand's type is cast to MemRefType (llvm::cast asserts on
// mismatch); its element type and memory space feed a type-construction call
// whose beginning is on omitted lines.
203 auto memrefType = llvm::cast<MemRefType>(*trueType);
206 memrefType.getElementType()),
207 memrefType.getMemorySpace());
// Attach each external model to its op class using the context `*ctx`.
// NOTE(review): the enclosing function/callback header is on lines not shown
// in this excerpt.
216 ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
217 IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
218 SelectOp::attachInterface<SelectOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
BufferizationState provides information about the state of the IR during the bufferization process.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
FailureOr< memref::GlobalOp > getGlobalFor(arith::ConstantOp constantOp, SymbolTableCollection &symbolTables, uint64_t alignment, Attribute memorySpace={})
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
Options for BufferizableOpInterface-based bufferization.