21 :
public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
45 static void populateIterArgBounds(scf::ForOp forOp,
Value value,
46 std::optional<int64_t> dim,
50 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
51 iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
53 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
56 Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
57 .getOperand(iterArgIdx);
58 Value iterArg = forOp.getRegionIterArg(iterArgIdx);
59 Value initArg = forOp.getInitArgs()[iterArgIdx];
65 ValueBoundsConstraintSet::ComparisonOperator::EQ,
67 if (dim.has_value()) {
77 auto forOp = cast<ForOp>(op);
79 if (value == forOp.getInductionVar()) {
81 cstr.
bound(value) >= forOp.getLowerBound();
82 cstr.
bound(value) < forOp.getUpperBound();
87 populateIterArgBounds(forOp, value, std::nullopt, cstr);
90 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
92 auto forOp = cast<ForOp>(op);
94 populateIterArgBounds(forOp, value, dim, cstr);
98 struct ForallOpInterface
99 :
public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
104 auto forallOp = cast<ForallOp>(op);
109 auto blockArg = cast<BlockArgument>(value);
110 assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
111 "expected index value to be an induction var");
112 int64_t idx = blockArg.getArgNumber();
116 cstr.
bound(value) >= lb;
117 cstr.
bound(value) < ub;
120 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
122 auto forallOp = cast<ForallOp>(op);
126 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
127 iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
129 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
134 Value outputOperand = forallOp.getOutputs()[iterArgIdx];
135 cstr.
bound(value)[dim] == cstr.
getExpr(outputOperand, dim);
140 :
public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
142 static void populateBounds(scf::IfOp ifOp,
Value value,
143 std::optional<int64_t> dim,
145 unsigned int resultNum = cast<OpResult>(value).getResultNumber();
146 Value thenValue = ifOp.thenYield().getResults()[resultNum];
147 Value elseValue = ifOp.elseYield().getResults()[resultNum];
149 auto boundsBuilder = cstr.
bound(value);
159 ValueBoundsConstraintSet::ComparisonOperator::LE,
162 cstr.
bound(value)[*dim] >= cstr.
getExpr(thenValue, dim);
163 cstr.
bound(value)[*dim] <= cstr.
getExpr(elseValue, dim);
165 cstr.
bound(value) >= thenValue;
166 cstr.
bound(value) <= elseValue;
174 ValueBoundsConstraintSet::ComparisonOperator::LE,
177 cstr.
bound(value)[*dim] >= cstr.
getExpr(elseValue, dim);
178 cstr.
bound(value)[*dim] <= cstr.
getExpr(thenValue, dim);
180 cstr.
bound(value) >= elseValue;
181 cstr.
bound(value) <= thenValue;
188 populateBounds(cast<IfOp>(op), value, std::nullopt, cstr);
191 void populateBoundsForShapedValueDim(
Operation *op,
Value value, int64_t dim,
193 populateBounds(cast<IfOp>(op), value, dim, cstr);
204 scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
205 scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
206 scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
Base type for affine expression.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Populate constraints for lhs/rhs (until the stop condition is met).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.