21 :
public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
23 ValueBoundsConstraintSet &cstr)
const {
24 auto castOp = cast<CastOp>(op);
25 assert(value == castOp.getResult() &&
"invalid value");
27 if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28 llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29 cstr.
bound(value)[dim] == cstr.
getExpr(castOp.getSource(), dim);
34struct CollapseShapeOpInterface
35 :
public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
37 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
38 ValueBoundsConstraintSet &cstr)
const {
39 auto collapseOp = cast<CollapseShapeOp>(op);
40 assert(value == collapseOp.getResult() &&
"invalid value");
44 collapseOp.getReassociationIndices()[dim];
45 AffineExpr productExpr =
46 cstr.
getExpr(collapseOp.getSrc(), reassocIndices[0]);
47 for (
size_t i = 1; i < reassocIndices.size(); ++i) {
49 productExpr * cstr.
getExpr(collapseOp.getSrc(), reassocIndices[i]);
51 cstr.
bound(value)[dim] == productExpr;
56 :
public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
57 void populateBoundsForIndexValue(Operation *op, Value value,
58 ValueBoundsConstraintSet &cstr)
const {
59 auto dimOp = cast<DimOp>(op);
60 assert(value == dimOp.getResult() &&
"invalid value");
62 cstr.
bound(value) >= 0;
63 auto constIndex = dimOp.getConstantIndex();
64 if (!constIndex.has_value())
66 cstr.
bound(value) == cstr.
getExpr(dimOp.getSource(), *constIndex);
70struct EmptyOpInterface
71 :
public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
72 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
73 ValueBoundsConstraintSet &cstr)
const {
74 auto emptyOp = cast<EmptyOp>(op);
75 assert(value == emptyOp.getResult() &&
"invalid value");
77 cstr.
bound(value)[dim] == emptyOp.getMixedSizes()[dim];
81struct ExpandShapeOpInterface
82 :
public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
84 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
85 ValueBoundsConstraintSet &cstr)
const {
86 auto expandOp = cast<ExpandShapeOp>(op);
87 assert(value == expandOp.getResult() &&
"invalid value");
88 cstr.
bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
92struct ExtractSliceOpInterface
93 :
public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
95 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
96 ValueBoundsConstraintSet &cstr)
const {
97 auto extractSliceOp = cast<ExtractSliceOp>(op);
98 assert(value == extractSliceOp.getResult() &&
"invalid value");
100 llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
102 for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
104 if (!dropped.test(i))
107 cstr.
bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
111 llvm_unreachable(
"could not find non-rank-reduced dim");
116 :
public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
117 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
118 ValueBoundsConstraintSet &cstr)
const {
119 auto padOp = cast<PadOp>(op);
120 assert(value == padOp.getResult() &&
"invalid value");
122 AffineExpr srcSize = cstr.
getExpr(padOp.getSource(), dim);
123 AffineExpr lowPad = cstr.
getExpr(padOp.getMixedLowPad()[dim]);
124 AffineExpr highPad = cstr.
getExpr(padOp.getMixedHighPad()[dim]);
125 cstr.
bound(value)[dim] == srcSize + lowPad + highPad;
129struct RankOpInterface
130 :
public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
131 void populateBoundsForIndexValue(Operation *op, Value value,
132 ValueBoundsConstraintSet &cstr)
const {
133 auto rankOp = cast<RankOp>(op);
134 assert(value == rankOp.getResult() &&
"invalid value");
137 llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
140 cstr.
bound(value) == tensorType.getRank();
151 tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
152 tensor::CollapseShapeOp::attachInterface<tensor::CollapseShapeOpInterface>(
154 tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
155 tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
156 tensor::ExpandShapeOp::attachInterface<tensor::ExpandShapeOpInterface>(
158 tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
160 tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
161 tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
SmallVector< int64_t, 2 > ReassociationIndices