32 uint64_t udiv = intVal.getZExtValue();
33 uint64_t sdiv = std::abs(intVal.getSExtValue());
43class AffineExprDivisibilityFinder
45 ConstantIntDivisibility> {
47 using ExprDivisibilityMap =
48 llvm::DenseMap<AffineExpr, ConstantIntDivisibility>;
49 AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap)
50 : divisibilityMap(divisibilityMap) {}
52 ConstantIntDivisibility visitConstantExpr(AffineConstantExpr expr) {
54 uint64_t constValue = std::abs(expr.
getValue());
55 return ConstantIntDivisibility(constValue, constValue);
58 ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) {
62 if (divisibilityMap.contains(expr))
63 return divisibilityMap[expr];
67 ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) {
71 if (divisibilityMap.contains(expr))
72 return divisibilityMap[expr];
78 ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) {
79 if (divisibilityMap.contains(expr))
80 return divisibilityMap[expr];
83 ConstantIntDivisibility lhsDiv =
visit(expr.
getLHS());
84 ConstantIntDivisibility rhsDiv =
visit(expr.
getRHS());
90 ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) {
91 if (divisibilityMap.contains(expr))
92 return divisibilityMap[expr];
95 ConstantIntDivisibility lhsDiv =
visit(expr.
getLHS());
96 ConstantIntDivisibility rhsDiv =
visit(expr.
getRHS());
97 return ConstantIntDivisibility(lhsDiv.
udiv() * rhsDiv.
udiv(),
101 ConstantIntDivisibility visitFloorDivExpr(AffineBinaryOpExpr expr) {
102 return visitDivExpr(expr);
105 ConstantIntDivisibility visitCeilDivExpr(AffineBinaryOpExpr expr) {
106 return visitDivExpr(expr);
114 ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) {
115 if (divisibilityMap.contains(expr))
116 return divisibilityMap[expr];
117 auto constRhs = dyn_cast<AffineConstantExpr>(expr.
getRHS());
118 if (!constRhs || constRhs.getValue() == 0)
119 return ConstantIntDivisibility(1, 1);
120 auto constValue =
static_cast<uint64_t
>(std::abs(constRhs.getValue()));
121 ConstantIntDivisibility lhsDiv =
visit(expr.
getLHS());
124 uint64_t modUDiv = (lhsDiv.
udiv() % constValue == 0)
126 : std::gcd(lhsDiv.
udiv(), constValue);
127 uint64_t modSDiv = (lhsDiv.
sdiv() % constValue == 0)
129 : std::gcd(lhsDiv.
sdiv(), constValue);
130 return ConstantIntDivisibility(modUDiv, modSDiv);
134 ConstantIntDivisibility visitInvalidExpr(AffineBinaryOpExpr expr) {
143 ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) {
144 if (divisibilityMap.contains(expr))
145 return divisibilityMap[expr];
146 auto constRhs = dyn_cast<AffineConstantExpr>(expr.
getRHS());
148 if (!constRhs || constRhs.getValue() == 0)
149 return ConstantIntDivisibility(1, 1);
150 auto constValue =
static_cast<uint64_t
>(std::abs(constRhs.getValue()));
151 ConstantIntDivisibility lhsDiv =
visit(expr.
getLHS());
153 lhsDiv.
udiv() % constValue == 0 ? lhsDiv.
udiv() / constValue : 1;
155 lhsDiv.
sdiv() % constValue == 0 ? lhsDiv.
sdiv() / constValue : 1;
156 return ConstantIntDivisibility(divUDiv, divSDiv);
159 ExprDivisibilityMap &divisibilityMap;
172 inputExprs.append(llvm::map_to_vector(
174 [&](
int64_t dim) { return getAffineDimExpr(dim, map.getContext()); }));
175 inputExprs.append(llvm::map_to_vector(
177 [&](
int64_t sym) { return getAffineSymbolExpr(sym, map.getContext()); }));
178 for (
auto [expr, divisibility] :
179 llvm::zip_equal(inputExprs, dimAndSymbolDivisibilities)) {
180 exprDivisibilityMap[expr] = divisibility;
182 AffineExprDivisibilityFinder divisibilityFinder(exprDivisibilityMap);
187 resultDivisibilities.push_back(divisibilityFinder.visit(resultExpr));
188 return resultDivisibilities;
195template <
typename MinOrMaxTy>
196void inferAffineMinOrMaxResultDivisibility(
199 static_assert(llvm::is_one_of<MinOrMaxTy, AffineMinOp, AffineMaxOp>::value,
200 "MinOrMaxTy must be AffineMinOp or AffineMaxOp");
202 for (
auto [operand, divisibility] :
203 llvm::zip(minOrMaxOp.getOperands(), argDivs)) {
204 operandDivisibilities.push_back(
209 getResultDivisibilities(minOrMaxOp.getMap(), operandDivisibilities);
212 resultDivisibilities.pop_back_val();
213 for (
auto divisibility : resultDivisibilities)
214 resultDivisibility = resultDivisibility.
getUnion(divisibility);
215 setResultDivs(minOrMaxOp.getResult(), resultDivisibility);
220void AffineApplyOp::inferResultDivisibility(
223 for (
auto [operand, divisibility] : llvm::zip(getOperands(), argDivs)) {
224 operandDivisibilities.push_back(
229 getResultDivisibilities(getMap(), operandDivisibilities);
230 for (
auto [
result, divisibility] :
231 llvm::zip_equal(getOperation()->getResults(), resultDivisibilities)) {
232 setResultDivs(
result, divisibility);
238 inferAffineMinOrMaxResultDivisibility(*
this, argDivs, setResultDivs);
243 inferAffineMinOrMaxResultDivisibility(*
this, argDivs, setResultDivs);
246void AffineDelinearizeIndexOp::inferResultDivisibility(
255 int64_t numResults = getNumResults();
263 operandDivs.push_back(linearDiv);
268 for (
int64_t i = 0, e =
static_cast<int64_t>(staticBasis.size()); i < e;
270 if (ShapedType::isDynamic(staticBasis[i])) {
273 argDivs[1 + dynIdx]));
281 bool hasOuter = hasOuterBound();
282 int64_t basisStart = hasOuter ? 1 : 0;
295 for (
int64_t i = numResults - 1; i >= 0; --i) {
298 resultExpr = linearExpr.
floorDiv(stride);
301 (linearExpr.
floorDiv(stride)) % basisExprs[basisStart + i - 1];
306 getResultDivisibilities(resultMap, operandDivs);
307 setResultDivs(getResult(i), divs[0]);
310 stride = basisExprs[basisStart + i - 1] * stride;
static ConstantIntDivisibility getDivisibilityOfOperand(Value v, IntegerDivisibility divisibility)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
AffineExpr getLHS() const
AffineExpr getRHS() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero-result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
Statically known divisibility information for an integer SSA value.
ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const
This lattice value represents the integer divisibility of an SSA value.
bool isUninitialized() const
const ConstantIntDivisibility & getValue() const
static IntegerDivisibility getMinDivisibility()
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
llvm::function_ref< void(Value, const ConstantIntDivisibility &)> SetIntDivisibilityFn
The type of the setResultDivs callback provided to ops implementing InferIntDivisibilityOpInterface.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the `detail` namespace directly.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)