21 :
public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
23 static AffineExpr getTripCountExpr(scf::ForOp forOp,
24 ValueBoundsConstraintSet &cstr) {
25 AffineExpr lbExpr = cstr.
getExpr(forOp.getLowerBound());
26 AffineExpr ubExpr = cstr.
getExpr(forOp.getUpperBound());
27 AffineExpr stepExpr = cstr.
getExpr(forOp.getStep());
28 AffineExpr tripCountExpr =
29 AffineExpr(ubExpr - lbExpr).
ceilDiv(stepExpr);
55 static void populateIterArgBounds(scf::ForOp forOp, Value value,
56 std::optional<int64_t> dim,
57 ValueBoundsConstraintSet &cstr) {
60 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
61 iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
63 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
66 Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
67 .getOperand(iterArgIdx);
68 Value iterArg = forOp.getRegionIterArg(iterArgIdx);
69 Value initArg = forOp.getInitArgs()[iterArgIdx];
77 if (dim.has_value()) {
84 if (dim.has_value() || isa<BlockArgument>(value))
90 AffineExpr tripCountExpr = getTripCountExpr(forOp, cstr);
91 AffineExpr oneIterAdvanceExpr =
94 cstr.
getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
97 void populateBoundsForIndexValue(Operation *op, Value value,
98 ValueBoundsConstraintSet &cstr)
const {
99 auto forOp = cast<ForOp>(op);
101 if (value == forOp.getInductionVar()) {
102 cstr.
bound(value) >= forOp.getLowerBound();
103 cstr.
bound(value) < forOp.getUpperBound();
108 AffineExpr tripCountMinusOne =
109 getTripCountExpr(forOp, cstr) - cstr.
getExpr(1);
110 AffineExpr computedUpperBound =
111 cstr.
getExpr(forOp.getLowerBound()) +
112 AffineExpr(tripCountMinusOne * cstr.
getExpr(forOp.getStep()));
113 cstr.
bound(value) <= computedUpperBound;
118 populateIterArgBounds(forOp, value, std::nullopt, cstr);
121 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
122 ValueBoundsConstraintSet &cstr)
const {
123 auto forOp = cast<ForOp>(op);
125 populateIterArgBounds(forOp, value, dim, cstr);
129struct ForallOpInterface
130 :
public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
133 void populateBoundsForIndexValue(Operation *op, Value value,
134 ValueBoundsConstraintSet &cstr)
const {
135 auto forallOp = cast<ForallOp>(op);
140 auto blockArg = cast<BlockArgument>(value);
141 assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
142 "expected index value to be an induction var");
143 int64_t idx = blockArg.getArgNumber();
145 AffineExpr lb = cstr.
getExpr(forallOp.getMixedLowerBound()[idx]);
146 AffineExpr ub = cstr.
getExpr(forallOp.getMixedUpperBound()[idx]);
147 cstr.
bound(value) >= lb;
148 cstr.
bound(value) < ub;
151 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
152 ValueBoundsConstraintSet &cstr)
const {
153 auto forallOp = cast<ForallOp>(op);
157 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
158 iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
160 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
165 Value outputOperand = forallOp.getOutputs()[iterArgIdx];
166 cstr.
bound(value)[dim] == cstr.
getExpr(outputOperand, dim);
171 :
public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
173 static void populateBounds(scf::IfOp ifOp, Value value,
174 std::optional<int64_t> dim,
175 ValueBoundsConstraintSet &cstr) {
176 unsigned int resultNum = cast<OpResult>(value).getResultNumber();
177 Value thenValue = ifOp.thenYield().getResults()[resultNum];
178 Value elseValue = ifOp.elseYield().getResults()[resultNum];
180 auto boundsBuilder = cstr.
bound(value);
193 cstr.
bound(value)[*dim] >= cstr.
getExpr(thenValue, dim);
194 cstr.
bound(value)[*dim] <= cstr.
getExpr(elseValue, dim);
196 cstr.
bound(value) >= thenValue;
197 cstr.
bound(value) <= elseValue;
208 cstr.
bound(value)[*dim] >= cstr.
getExpr(elseValue, dim);
209 cstr.
bound(value)[*dim] <= cstr.
getExpr(thenValue, dim);
211 cstr.
bound(value) >= elseValue;
212 cstr.
bound(value) <= thenValue;
217 void populateBoundsForIndexValue(Operation *op, Value value,
218 ValueBoundsConstraintSet &cstr)
const {
219 populateBounds(cast<IfOp>(op), value, std::nullopt, cstr);
222 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
223 ValueBoundsConstraintSet &cstr)
const {
224 populateBounds(cast<IfOp>(op), value, dim, cstr);
235 scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
236 scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
237 scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
AffineExpr ceilDiv(uint64_t v) const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Populate constraints for lhs/rhs (until the stop condition is met).
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Register the ValueBoundsOpInterface external models for the SCF ops scf.for, scf.forall and scf.if.