// External model implementing ValueBoundsOpInterface for CastOp.
// NOTE(review): this is a listing fragment; the embedded line numbers skip, so
// the end of the parameter list and the closing braces are not visible here.
20 struct CastOpInterface
21 :
public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
// Adds a constraint for dimension `dim` of the shaped value `value`.
// `cstr` is not declared in the visible lines; presumably it is the trailing
// constraint-set parameter (confirm against the full file).
22 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
24 auto castOp = cast<CastOp>(op);
// The queried value must be the result of this cast.
25 assert(value == castOp.getResult() &&
"invalid value");
// The constraint is only added when both the result and the source are
// ranked tensors.
27 if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28 llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
// Result dimension `dim` equals the same dimension of the cast source.
29 cstr.
bound(value)[dim] == cstr.
getExpr(castOp.getSource(), dim);
// NOTE(review): the declaration line naming this struct is missing from the
// listing; the base class below identifies it as the external model for DimOp.
35 :
public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
// NOTE(review): the function signature (listing lines 36-37) is not visible.
38 auto dimOp = cast<DimOp>(op);
// The queried value must be the result of this dim op.
39 assert(value == dimOp.getResult() &&
"invalid value");
// The dimension index is only usable when it is a known constant; the result
// is checked with has_value() before being dereferenced below.
41 auto constIndex = dimOp.getConstantIndex();
42 if (!constIndex.has_value())
// NOTE(review): listing line 43 is missing here. Presumably it is an early
// `return;`, since the statement below dereferences `constIndex` -- confirm
// against the full file before reading the bound as the body of this `if`.
// The dim op's result equals dimension `*constIndex` of its source.
44 cstr.
bound(value) == cstr.
getExpr(dimOp.getSource(), *constIndex);
// External model implementing ValueBoundsOpInterface for EmptyOp.
// NOTE(review): listing fragment; the end of the parameter list and the
// closing braces are not visible here.
48 struct EmptyOpInterface
49 :
public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
// Adds a constraint for dimension `dim` of the shaped value `value`.
50 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
52 auto emptyOp = cast<EmptyOp>(op);
// The queried value must be the result of this op.
53 assert(value == emptyOp.getResult() &&
"invalid value");
// Result dimension `dim` equals the op's mixed size entry at index `dim`.
55 cstr.
bound(value)[dim] == emptyOp.getMixedSizes()[dim];
// External model implementing ValueBoundsOpInterface for ExtractSliceOp.
// NOTE(review): listing fragment; several original lines are missing (the
// embedded line numbers skip 61, 63, 66, 68, 70-73 and 75-77).
59 struct ExtractSliceOpInterface
60 :
public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
// Adds a constraint for dimension `dim` of the shaped value `value`.
62 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
64 auto extractSliceOp = cast<ExtractSliceOp>(op);
// The queried value must be the result of this op.
65 assert(value == extractSliceOp.getResult() &&
"invalid value");
// Bit vector of the dimensions dropped by the slice. Its use is not visible in
// this fragment; presumably the missing loop lines consult it to map the
// result dimension `dim` onto a size index `i` -- confirm.
67 llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
// Walk the indices of the op's mixed sizes.
69 for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
// Result dimension `dim` equals the mixed size entry at index `i`.
// NOTE(review): the condition selecting `i` (listing lines 70-73) is missing.
74 cstr.
bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
// Marks the point after the loop as unreachable.
78 llvm_unreachable(
"could not find non-rank-reduced dim");
// NOTE(review): the declaration line naming this struct is missing from the
// listing; the base class below identifies it as the external model for PadOp.
83 :
public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
// Adds a constraint for dimension `dim` of the shaped value `value`.
84 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
86 auto padOp = cast<PadOp>(op);
// The queried value must be the result of this pad op.
87 assert(value == padOp.getResult() &&
"invalid value");
// Result dimension `dim` equals srcSize + lowPad + highPad.
// NOTE(review): the definitions of `srcSize`, `lowPad` and `highPad`
// (listing lines 88-91) are missing from this fragment.
92 cstr.
bound(value)[dim] == srcSize + lowPad + highPad;
// External model implementing ValueBoundsOpInterface for RankOp.
// NOTE(review): listing fragment; the function signature (listing lines 98-99)
// and the closing braces are not visible here.
96 struct RankOpInterface
97 :
public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
100 auto rankOp = cast<RankOp>(op);
// The queried value must be the result of this rank op.
101 assert(value == rankOp.getResult() &&
"invalid value");
// Attempts to view the operand's type as a RankedTensorType.
// NOTE(review): the left-hand side of this statement (presumably the
// declaration of `tensorType`) and any check of the dyn_cast result (listing
// lines 102-103 and 105-106) are missing from this fragment -- confirm.
104 llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
// The rank op's result equals the rank of the tensor type.
107 cstr.
bound(value) == tensorType.getRank();
// Attaches each external model to its corresponding tensor op, using the
// context pointed to by `ctx`.
// NOTE(review): the enclosing function header and the origin of `ctx` are not
// visible in this fragment.
118 tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
119 tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
120 tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
// NOTE(review): the argument list of the next call (listing line 122) is
// missing from this fragment.
121 tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
123 tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
124 tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
Base type for affine expression.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.