23 struct ConstantOpInterface
24 :
public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
28 auto constantOp = cast<arith::ConstantOp>(op);
29 auto type = dyn_cast<RankedTensorType>(constantOp.getType());
36 if (
auto memSpace =
options.defaultMemorySpaceFn(type))
37 memorySpace = *memSpace;
39 return constantOp->emitError(
"could not infer memory space");
42 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
48 FailureOr<memref::GlobalOp> globalOp =
52 memref::GlobalOp globalMemref = *globalOp;
53 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
54 rewriter, op, globalMemref.getType(), globalMemref.getName());
62 assert(isa<OpResult>(value));
67 struct IndexCastOpInterface
68 :
public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
87 auto castOp = cast<arith::IndexCastOp>(op);
88 auto resultTensorType = cast<TensorType>(castOp.getType());
93 auto sourceType = cast<BaseMemRefType>(source->getType());
97 if (
auto rankedMemRefType = dyn_cast<MemRefType>(sourceType)) {
99 rankedMemRefType.getShape(), resultTensorType.getElementType(),
100 rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
102 auto unrankedMemrefType = cast<UnrankedMemRefType>(sourceType);
104 unrankedMemrefType.getMemorySpace());
107 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
114 struct SelectOpInterface
115 :
public BufferizableOpInterface::ExternalModel<SelectOpInterface,
135 auto selectOp = cast<arith::SelectOp>(op);
142 if (!selectOp.getCondition().getType().isInteger(1))
143 return op->
emitOpError(
"only i1 condition values are supported");
149 FailureOr<Value> maybeTrueBuffer =
151 FailureOr<Value> maybeFalseBuffer =
153 if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
155 Value trueBuffer = *maybeTrueBuffer;
156 Value falseBuffer = *maybeFalseBuffer;
164 if (failed(targetType))
166 if (trueBuffer.
getType() != *targetType)
168 rewriter.
create<memref::CastOp>(loc, *targetType, trueBuffer);
169 if (falseBuffer.
getType() != *targetType)
171 rewriter.
create<memref::CastOp>(loc, *targetType, falseBuffer);
174 replaceOpWithNewBufferizedOp<arith::SelectOp>(
175 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
179 FailureOr<BaseMemRefType>
182 auto selectOp = cast<arith::SelectOp>(op);
183 assert(value == selectOp.getResult() &&
"invalid value");
188 if (failed(trueType) || failed(falseType))
190 if (*trueType == *falseType)
192 if (trueType->getMemorySpace() != falseType->getMemorySpace())
193 return op->
emitError(
"inconsistent memory space on true/false operands");
197 auto memrefType = llvm::cast<MemRefType>(*trueType);
200 memrefType.getElementType()),
201 memrefType.getMemorySpace());
210 ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
211 IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
212 SelectOp::attachInterface<SelectOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
FailureOr< memref::GlobalOp > getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, Attribute memorySpace={})
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
Options for BufferizableOpInterface-based bufferization.