20template <
typename OpTy>
21struct AllocOpInterface
22 :
public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
24 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
25 ValueBoundsConstraintSet &cstr)
const {
26 auto allocOp = cast<OpTy>(op);
27 assert(value == allocOp.getResult() &&
"invalid value");
29 cstr.
bound(value)[dim] == allocOp.getMixedSizes()[dim];
34 :
public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
35 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
36 ValueBoundsConstraintSet &cstr)
const {
37 auto castOp = cast<CastOp>(op);
38 assert(value == castOp.getResult() &&
"invalid value");
40 if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
41 llvm::isa<MemRefType>(castOp.getSource().getType())) {
42 cstr.
bound(value)[dim] == cstr.
getExpr(castOp.getSource(), dim);
48 :
public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
49 void populateBoundsForIndexValue(Operation *op, Value value,
50 ValueBoundsConstraintSet &cstr)
const {
51 auto dimOp = cast<DimOp>(op);
52 assert(value == dimOp.getResult() &&
"invalid value");
54 cstr.
bound(value) >= 0;
55 auto constIndex = dimOp.getConstantIndex();
56 if (!constIndex.has_value())
58 cstr.
bound(value) == cstr.
getExpr(dimOp.getSource(), *constIndex);
62struct ExpandShapeOpInterface
63 :
public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
64 memref::ExpandShapeOp> {
65 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
66 ValueBoundsConstraintSet &cstr)
const {
67 auto expandOp = cast<memref::ExpandShapeOp>(op);
68 assert(value == expandOp.getResult() &&
"invalid value");
69 cstr.
bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
73struct GetGlobalOpInterface
74 :
public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
76 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
77 ValueBoundsConstraintSet &cstr)
const {
78 auto getGlobalOp = cast<GetGlobalOp>(op);
79 assert(value == getGlobalOp.getResult() &&
"invalid value");
81 auto type = getGlobalOp.getType();
82 assert(!type.isDynamicDim(dim) &&
"expected static dim");
83 cstr.
bound(value)[dim] == type.getDimSize(dim);
88 :
public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
89 void populateBoundsForIndexValue(Operation *op, Value value,
90 ValueBoundsConstraintSet &cstr)
const {
91 auto rankOp = cast<RankOp>(op);
92 assert(value == rankOp.getResult() &&
"invalid value");
94 auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
97 cstr.
bound(value) == memrefType.getRank();
101struct CollapseShapeOpInterface
102 :
public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
103 memref::CollapseShapeOp> {
104 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
105 ValueBoundsConstraintSet &cstr)
const {
106 auto collapseOp = cast<memref::CollapseShapeOp>(op);
107 assert(value == collapseOp.getResult() &&
"invalid value");
111 collapseOp.getReassociationIndices()[dim];
112 AffineExpr productExpr =
113 cstr.
getExpr(collapseOp.getSrc(), reassocIndices[0]);
114 for (
size_t i = 1; i < reassocIndices.size(); ++i) {
116 productExpr * cstr.
getExpr(collapseOp.getSrc(), reassocIndices[i]);
118 cstr.
bound(value)[dim] == productExpr;
122struct SubViewOpInterface
123 :
public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
125 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
126 ValueBoundsConstraintSet &cstr)
const {
127 auto subViewOp = cast<SubViewOp>(op);
128 assert(value == subViewOp.getResult() &&
"invalid value");
130 llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
132 for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
134 if (!dropped.test(i))
137 cstr.
bound(value)[dim] == subViewOp.getMixedSizes()[i];
141 llvm_unreachable(
"could not find non-rank-reduced dim");
152 memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
154 memref::AllocaOp::attachInterface<
155 memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
156 memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
157 memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
158 memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
160 memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
162 memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
163 memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
164 memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Register the ValueBoundsOpInterface external models for the memref dialect ops.
SmallVector< int64_t, 2 > ReassociationIndices