23 struct ConstantOpInterface
24 :
public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
28 auto constantOp = cast<arith::ConstantOp>(op);
31 if (
options.defaultMemorySpace.has_value())
32 memorySpace = *
options.defaultMemorySpace;
34 return constantOp->emitError(
"could not infer memory space");
37 if (!isa<RankedTensorType>(constantOp.getType()))
41 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
51 memref::GlobalOp globalMemref = *globalOp;
52 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
53 rewriter, op, globalMemref.getType(), globalMemref.getName());
61 assert(isa<OpResult>(value));
66 struct IndexCastOpInterface
67 :
public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
86 auto castOp = cast<arith::IndexCastOp>(op);
87 auto resultTensorType = cast<TensorType>(castOp.getType());
92 auto sourceType = cast<BaseMemRefType>(source->getType());
96 if (
auto rankedMemRefType = dyn_cast<MemRefType>(sourceType)) {
98 rankedMemRefType.getShape(), resultTensorType.getElementType(),
99 rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
101 auto unrankedMemrefType = cast<UnrankedMemRefType>(sourceType);
103 unrankedMemrefType.getMemorySpace());
106 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
113 struct SelectOpInterface
114 :
public BufferizableOpInterface::ExternalModel<SelectOpInterface,
134 auto selectOp = cast<arith::SelectOp>(op);
141 if (!selectOp.getCondition().getType().isInteger(1))
142 return op->
emitOpError(
"only i1 condition values are supported");
152 if (
failed(maybeTrueBuffer) ||
failed(maybeFalseBuffer))
154 Value trueBuffer = *maybeTrueBuffer;
155 Value falseBuffer = *maybeFalseBuffer;
165 if (trueBuffer.
getType() != *targetType)
167 rewriter.
create<memref::CastOp>(loc, *targetType, trueBuffer);
168 if (falseBuffer.
getType() != *targetType)
170 rewriter.
create<memref::CastOp>(loc, *targetType, falseBuffer);
173 replaceOpWithNewBufferizedOp<arith::SelectOp>(
174 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
181 auto selectOp = cast<arith::SelectOp>(op);
182 assert(value == selectOp.getResult() &&
"invalid value");
189 if (*trueType == *falseType)
191 if (trueType->getMemorySpace() != falseType->getMemorySpace())
192 return op->
emitError(
"inconsistent memory space on true/false operands");
196 auto memrefType = llvm::cast<MemRefType>(*trueType);
199 memrefType.getElementType()),
200 memrefType.getMemorySpace());
209 ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
210 IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
211 SelectOp::attachInterface<SelectOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
FailureOr< memref::GlobalOp > getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, Attribute memorySpace={})
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.