20static AffineExpr getTripCountExpr(OpFoldResult lb, OpFoldResult ub,
22 ValueBoundsConstraintSet &cstr) {
23 AffineExpr lbExpr = cstr.
getExpr(lb);
24 AffineExpr ubExpr = cstr.
getExpr(ub);
25 AffineExpr stepExpr = cstr.
getExpr(step);
26 AffineExpr tripCountExpr =
27 AffineExpr(ubExpr - lbExpr).
ceilDiv(stepExpr);
31static void populateIVBounds(OpFoldResult lb, OpFoldResult ub,
32 OpFoldResult step, Value iv,
33 ValueBoundsConstraintSet &cstr) {
40 AffineExpr tripCountMinusOne =
41 getTripCountExpr(lb, ub, step, cstr) - cstr.
getExpr(1);
42 AffineExpr computedUpperBound =
43 cstr.
getExpr(lb) + AffineExpr(tripCountMinusOne * cstr.
getExpr(step));
44 cstr.
bound(iv) <= computedUpperBound;
48 :
public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
72 static void populateIterArgBounds(scf::ForOp forOp, Value value,
73 std::optional<int64_t> dim,
74 ValueBoundsConstraintSet &cstr) {
77 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
78 iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
80 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
83 Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
84 .getOperand(iterArgIdx);
85 Value iterArg = forOp.getRegionIterArg(iterArgIdx);
86 Value initArg = forOp.getInitArgs()[iterArgIdx];
94 if (dim.has_value()) {
101 if (dim.has_value() || isa<BlockArgument>(value))
107 AffineExpr tripCountExpr = getTripCountExpr(
108 forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), cstr);
109 AffineExpr oneIterAdvanceExpr =
112 cstr.
getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
115 void populateBoundsForIndexValue(Operation *op, Value value,
116 ValueBoundsConstraintSet &cstr)
const {
117 auto forOp = cast<ForOp>(op);
119 if (value == forOp.getInductionVar()) {
120 return populateIVBounds(forOp.getLowerBound(), forOp.getUpperBound(),
121 forOp.getStep(), value, cstr);
125 populateIterArgBounds(forOp, value, std::nullopt, cstr);
128 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
129 ValueBoundsConstraintSet &cstr)
const {
130 auto forOp = cast<ForOp>(op);
132 populateIterArgBounds(forOp, value, dim, cstr);
136struct ForallOpInterface
137 :
public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
140 void populateBoundsForIndexValue(Operation *op, Value value,
141 ValueBoundsConstraintSet &cstr)
const {
142 auto forallOp = cast<ForallOp>(op);
147 auto blockArg = cast<BlockArgument>(value);
148 assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
149 "expected index value to be an induction var");
150 int64_t idx = blockArg.getArgNumber();
151 return populateIVBounds(forallOp.getMixedLowerBound()[idx],
152 forallOp.getMixedUpperBound()[idx],
153 forallOp.getMixedStep()[idx], value, cstr);
156 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
157 ValueBoundsConstraintSet &cstr)
const {
158 auto forallOp = cast<ForallOp>(op);
162 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
163 iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
165 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
170 Value outputOperand = forallOp.getOutputs()[iterArgIdx];
171 cstr.
bound(value)[dim] == cstr.
getExpr(outputOperand, dim);
176 :
public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
178 static void populateBounds(scf::IfOp ifOp, Value value,
179 std::optional<int64_t> dim,
180 ValueBoundsConstraintSet &cstr) {
181 unsigned int resultNum = cast<OpResult>(value).getResultNumber();
182 Value thenValue = ifOp.thenYield().getResults()[resultNum];
183 Value elseValue = ifOp.elseYield().getResults()[resultNum];
185 auto boundsBuilder = cstr.
bound(value);
198 cstr.
bound(value)[*dim] >= cstr.
getExpr(thenValue, dim);
199 cstr.
bound(value)[*dim] <= cstr.
getExpr(elseValue, dim);
201 cstr.
bound(value) >= thenValue;
202 cstr.
bound(value) <= elseValue;
213 cstr.
bound(value)[*dim] >= cstr.
getExpr(elseValue, dim);
214 cstr.
bound(value)[*dim] <= cstr.
getExpr(thenValue, dim);
216 cstr.
bound(value) >= elseValue;
217 cstr.
bound(value) <= thenValue;
222 void populateBoundsForIndexValue(Operation *op, Value value,
223 ValueBoundsConstraintSet &cstr)
const {
224 populateBounds(cast<IfOp>(op), value, std::nullopt, cstr);
227 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
228 ValueBoundsConstraintSet &cstr)
const {
229 populateBounds(cast<IfOp>(op), value, dim, cstr);
240 scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
241 scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
242 scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
AffineExpr ceilDiv(uint64_t v) const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Populate constraints for lhs/rhs (until the stop condition is met).
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.