20 struct CastOpInterface
21 :
public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
24 auto castOp = cast<CastOp>(op);
25 assert(value == castOp.getResult() &&
"invalid value");
27 if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28 llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29 cstr.
bound(value)[dim] == cstr.
getExpr(castOp.getSource(), dim);
35 :
public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
38 auto dimOp = cast<DimOp>(op);
39 assert(value == dimOp.getResult() &&
"invalid value");
41 cstr.
bound(value) >= 0;
42 auto constIndex = dimOp.getConstantIndex();
43 if (!constIndex.has_value())
45 cstr.
bound(value) == cstr.
getExpr(dimOp.getSource(), *constIndex);
49 struct EmptyOpInterface
50 :
public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
51 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
53 auto emptyOp = cast<EmptyOp>(op);
54 assert(value == emptyOp.getResult() &&
"invalid value");
56 cstr.
bound(value)[dim] == emptyOp.getMixedSizes()[dim];
60 struct ExtractSliceOpInterface
61 :
public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
63 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
65 auto extractSliceOp = cast<ExtractSliceOp>(op);
66 assert(value == extractSliceOp.getResult() &&
"invalid value");
68 llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
70 for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
75 cstr.
bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
79 llvm_unreachable(
"could not find non-rank-reduced dim");
84 :
public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
85 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
87 auto padOp = cast<PadOp>(op);
88 assert(value == padOp.getResult() &&
"invalid value");
93 cstr.
bound(value)[dim] == srcSize + lowPad + highPad;
97 struct RankOpInterface
98 :
public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
101 auto rankOp = cast<RankOp>(op);
102 assert(value == rankOp.getResult() &&
"invalid value");
105 llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
108 cstr.
bound(value) == tensorType.getRank();
119 tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
120 tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
121 tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
122 tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
124 tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
125 tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
Base type for affine expressions.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.