22 struct ConstantOpInterface
23 :
public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
27 auto constantOp = cast<arith::ConstantOp>(op);
32 return op->
emitError(
"memory space not implemented yet");
35 if (!constantOp.getType().isa<RankedTensorType>())
39 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
49 memref::GlobalOp globalMemref = *globalOp;
50 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
51 rewriter, op, globalMemref.getType(), globalMemref.getName());
64 struct IndexCastOpInterface
65 :
public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
89 auto castOp = cast<arith::IndexCastOp>(op);
90 auto resultTensorType = castOp.getType().cast<
TensorType>();
99 if (
auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
100 resultType = MemRefType::get(
101 rankedMemRefType.getShape(), resultTensorType.getElementType(),
102 rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
105 resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
106 unrankedMemrefType.getMemorySpace());
109 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
116 struct SelectOpInterface
117 :
public BufferizableOpInterface::ExternalModel<SelectOpInterface,
143 auto selectOp = cast<arith::SelectOp>(op);
154 if (
failed(maybeTrueBuffer) ||
failed(maybeFalseBuffer))
156 Value trueBuffer = *maybeTrueBuffer;
157 Value falseBuffer = *maybeFalseBuffer;
168 rewriter.
create<memref::CastOp>(loc, *targetType, trueBuffer);
170 rewriter.
create<memref::CastOp>(loc, *targetType, falseBuffer);
173 replaceOpWithNewBufferizedOp<arith::SelectOp>(
174 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
181 auto selectOp = cast<arith::SelectOp>(op);
182 assert(value == selectOp.getResult() &&
"invalid value");
189 if (*trueType == *falseType)
191 if (trueType->getMemorySpace() != falseType->getMemorySpace())
192 return op->
emitError(
"inconsistent memory space on true/false operands");
196 auto memrefType = trueType->cast<MemRefType>();
198 RankedTensorType::get(memrefType.getShape(),
199 memrefType.getElementType()),
200 memrefType.getMemorySpace());
214 ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
215 IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
216 SelectOp::attachInterface<SelectOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
OpOperand & getOpOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
FailureOr< memref::GlobalOp > getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment)
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specify fine-grain relationship between buffers to enable more analysis.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.