23struct ConstantOpInterface
24 :
public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
26 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
27 const BufferizationOptions &
options,
28 BufferizationState &state)
const {
29 auto constantOp = cast<arith::ConstantOp>(op);
30 auto type = dyn_cast<RankedTensorType>(constantOp.getType());
36 Attribute memorySpace;
37 if (
auto memSpace =
options.defaultMemorySpaceFn(type))
38 memorySpace = *memSpace;
40 return constantOp->emitError(
"could not infer memory space");
43 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
49 FailureOr<memref::GlobalOp> globalOp =
51 options.bufferAlignment, memorySpace);
54 memref::GlobalOp globalMemref = *globalOp;
55 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
56 rewriter, op, globalMemref.getType(), globalMemref.getName());
61 bool isWritable(Operation *op, Value value,
62 const AnalysisState &state)
const {
64 assert(isa<OpResult>(value));
69struct IndexCastOpInterface
70 :
public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
72 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
73 const AnalysisState &state)
const {
77 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
78 const AnalysisState &state)
const {
82 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
83 const AnalysisState &state)
const {
84 return {{op->
getResult(0), BufferRelation::Equivalent}};
87 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
88 const BufferizationOptions &
options,
89 BufferizationState &state)
const {
90 auto castOp = cast<arith::IndexCastOp>(op);
91 auto resultTensorType = cast<TensorType>(castOp.getType());
93 FailureOr<Value> source =
94 getBuffer(rewriter, castOp.getIn(),
options, state);
97 auto sourceType = cast<BaseMemRefType>(source->getType());
100 BaseMemRefType resultType;
101 if (
auto rankedMemRefType = dyn_cast<MemRefType>(sourceType)) {
102 resultType = MemRefType::get(
103 rankedMemRefType.getShape(), resultTensorType.getElementType(),
104 rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
106 auto unrankedMemrefType = cast<UnrankedMemRefType>(sourceType);
107 resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
108 unrankedMemrefType.getMemorySpace());
111 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
118struct SelectOpInterface
119 :
public BufferizableOpInterface::ExternalModel<SelectOpInterface,
121 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
122 const AnalysisState &state)
const {
126 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
127 const AnalysisState &state)
const {
131 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
132 const AnalysisState &state)
const {
133 return {{op->
getOpResult(0) , BufferRelation::Equivalent,
137 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
138 const BufferizationOptions &
options,
139 BufferizationState &state)
const {
140 auto selectOp = cast<arith::SelectOp>(op);
141 Location loc = selectOp.getLoc();
147 if (!selectOp.getCondition().getType().isInteger(1))
148 return op->
emitOpError(
"only i1 condition values are supported");
154 FailureOr<Value> maybeTrueBuffer =
155 getBuffer(rewriter, selectOp.getTrueValue(),
options, state);
156 FailureOr<Value> maybeFalseBuffer =
157 getBuffer(rewriter, selectOp.getFalseValue(),
options, state);
158 if (
failed(maybeTrueBuffer) ||
failed(maybeFalseBuffer))
160 Value trueBuffer = *maybeTrueBuffer;
161 Value falseBuffer = *maybeFalseBuffer;
167 auto targetType = bufferization::detail::asMemRefType(
168 bufferization::getBufferType(selectOp.getResult(),
options, state));
171 if (trueBuffer.
getType() != *targetType)
173 memref::CastOp::create(rewriter, loc, *targetType, trueBuffer);
174 if (falseBuffer.
getType() != *targetType)
176 memref::CastOp::create(rewriter, loc, *targetType, falseBuffer);
179 replaceOpWithNewBufferizedOp<arith::SelectOp>(
180 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
184 FailureOr<BufferLikeType>
186 const BufferizationState &state,
187 SmallVector<Value> &invocationStack)
const {
188 auto selectOp = cast<arith::SelectOp>(op);
189 assert(value == selectOp.getResult() &&
"invalid value");
191 bufferization::detail::asMemRefType(bufferization::getBufferType(
192 selectOp.getTrueValue(),
options, state, invocationStack));
194 bufferization::detail::asMemRefType(bufferization::getBufferType(
195 selectOp.getFalseValue(),
options, state, invocationStack));
198 if (*trueType == *falseType)
199 return cast<BufferLikeType>(*trueType);
200 if (trueType->getMemorySpace() != falseType->getMemorySpace())
201 return op->
emitError(
"inconsistent memory space on true/false operands");
205 auto memrefType = llvm::cast<MemRefType>(*trueType);
206 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
207 RankedTensorType::get(memrefType.getShape(),
208 memrefType.getElementType()),
209 memrefType.getMemorySpace()));
218 ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
219 IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
220 SelectOp::attachInterface<SelectOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
OpResult getOpResult(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Type getType() const
Return the type of this value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
FailureOr< memref::GlobalOp > getGlobalFor(arith::ConstantOp constantOp, SymbolTableCollection &symbolTables, uint64_t alignment, Attribute memorySpace={})
Include the generated interface declarations.