MLIR 23.0.0git
InferIntDivisibilityOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- InferIntDivisibilityOpInterfaceImpl.cpp ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Direct implementations of `InferIntDivisibilityOpInterface` for affine ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Matchers.h"
17
18#include <cstdlib>
19#include <numeric>
20
21using namespace mlir;
22using namespace mlir::affine;
23
24namespace {
25
28 if (!divisibility.isUninitialized())
29 return divisibility.getValue();
30 APInt intVal;
31 if (matchPattern(v, m_ConstantInt(&intVal))) {
32 uint64_t udiv = intVal.getZExtValue();
33 uint64_t sdiv = std::abs(intVal.getSExtValue());
34 return ConstantIntDivisibility(udiv, sdiv);
35 }
36 return ConstantIntDivisibility(1, 1);
37}
38
/// Visits affine expressions and recursively calculates the divisibilities of
/// each subexpression. The caller provides a reference to a map
/// (`divisibilityMap`) that may be pre-seeded with known divisibilities (for
/// example, those of the dims and symbols); each visitor consults that map
/// before computing a result. Computed divisibilities are returned to the
/// caller rather than written back into the map.
43class AffineExprDivisibilityFinder
44 : public AffineExprVisitor<AffineExprDivisibilityFinder,
45 ConstantIntDivisibility> {
46public:
47 using ExprDivisibilityMap =
48 llvm::DenseMap<AffineExpr, ConstantIntDivisibility>;
49 AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap)
50 : divisibilityMap(divisibilityMap) {}
51
52 ConstantIntDivisibility visitConstantExpr(AffineConstantExpr expr) {
53 // Constant expressions are trivial, since they are always static.
54 uint64_t constValue = std::abs(expr.getValue());
55 return ConstantIntDivisibility(constValue, constValue);
56 }
57
58 ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) {
59 // Dim expressions cannot be analyzed further, so return the divisibility
60 // in `divisibilityMap` if it has been populated by the caller, or fallback
61 // to the minimum divisibility.
62 if (divisibilityMap.contains(expr))
63 return divisibilityMap[expr];
65 }
66
67 ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) {
68 // Symbol expressions cannot be analyzed further, so return the divisibility
69 // in `divisibilityMap` if it has been populated by the caller, or fallback
70 // to the minimum divisibility.
71 if (divisibilityMap.contains(expr))
72 return divisibilityMap[expr];
74 }
75
76 /// Infer the divisibility of an addition or subtraction expression by
77 /// recursively visiting the LHS and RHS, and then unioning the results.
78 ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) {
79 if (divisibilityMap.contains(expr))
80 return divisibilityMap[expr];
81 // The divisibility of an addition is the GCD of its constituents'
82 // divisibilities.
83 ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
84 ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
85 return lhsDiv.getUnion(rhsDiv);
86 }
87
88 /// Infer the divisibility of a multiplication expression by recursively
89 /// visiting the LHS and RHS, and then multiplying the results.
90 ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) {
91 if (divisibilityMap.contains(expr))
92 return divisibilityMap[expr];
93 // The divisibility of a multiplication is the product of its constituents'
94 // divisibilities.
95 ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
96 ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
97 return ConstantIntDivisibility(lhsDiv.udiv() * rhsDiv.udiv(),
98 lhsDiv.sdiv() * rhsDiv.sdiv());
99 }
100
101 ConstantIntDivisibility visitFloorDivExpr(AffineBinaryOpExpr expr) {
102 return visitDivExpr(expr);
103 }
104
105 ConstantIntDivisibility visitCeilDivExpr(AffineBinaryOpExpr expr) {
106 return visitDivExpr(expr);
107 }
108
109 /// Infer the divisibility of a mod expression. If the RHS is a constant,
110 /// the result divisibility is gcd(lhs_divisibility, rhs_constant), since
111 /// (d * k) mod c is always divisible by gcd(d, c). Furthermore, if the
112 /// LHS divisibility is itself divisible by the constant (i.e., d % c == 0),
113 /// then (d * k) mod c is always zero, represented as divisibility 0.
114 ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) {
115 if (divisibilityMap.contains(expr))
116 return divisibilityMap[expr];
117 auto constRhs = dyn_cast<AffineConstantExpr>(expr.getRHS());
118 if (!constRhs || constRhs.getValue() == 0)
119 return ConstantIntDivisibility(1, 1);
120 auto constValue = static_cast<uint64_t>(std::abs(constRhs.getValue()));
121 ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
122 // If the LHS is always a multiple of constValue, x mod constValue is
123 // always zero. Divisibility 0 is the lattice top ("divides everything").
124 uint64_t modUDiv = (lhsDiv.udiv() % constValue == 0)
125 ? 0
126 : std::gcd(lhsDiv.udiv(), constValue);
127 uint64_t modSDiv = (lhsDiv.sdiv() % constValue == 0)
128 ? 0
129 : std::gcd(lhsDiv.sdiv(), constValue);
130 return ConstantIntDivisibility(modUDiv, modSDiv);
131 }
132
133private:
134 ConstantIntDivisibility visitInvalidExpr(AffineBinaryOpExpr expr) {
136 }
137
138 /// Helper shared by ceildiv and floordiv implementations. Returns the minimum
139 /// divisibility as a fallback if the divisor is not a constant, because the
140 /// divisibility cannot be inferred in this case. If the divisor is a
141 /// constant, then this function recursively visits the dividend, and returns
142 /// the quotient of the dividend's divisibility with the divisor.
143 ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) {
144 if (divisibilityMap.contains(expr))
145 return divisibilityMap[expr];
146 auto constRhs = dyn_cast<AffineConstantExpr>(expr.getRHS());
147 // Division by zero is undefined, so return the minimum divisibility.
148 if (!constRhs || constRhs.getValue() == 0)
149 return ConstantIntDivisibility(1, 1);
150 auto constValue = static_cast<uint64_t>(std::abs(constRhs.getValue()));
151 ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
152 uint64_t divUDiv =
153 lhsDiv.udiv() % constValue == 0 ? lhsDiv.udiv() / constValue : 1;
154 uint64_t divSDiv =
155 lhsDiv.sdiv() % constValue == 0 ? lhsDiv.sdiv() / constValue : 1;
156 return ConstantIntDivisibility(divUDiv, divSDiv);
157 }
158
159 ExprDivisibilityMap &divisibilityMap;
160};
161
162/// Returns the divisibilities of each AffineMap result based on the
163/// divisibilities of its dims and symbols. The `dimAndSymbolDivisibilities`
164/// should contain the divisibilities of the dims, followed by the
165/// divisibilities of the symbols in ascending order by their positions.
166SmallVector<ConstantIntDivisibility> getResultDivisibilities(
167 AffineMap map,
168 ArrayRef<ConstantIntDivisibility> dimAndSymbolDivisibilities) {
169 // Seed the AffineExprDivisibilityFinder with the dimAndSymbolDivisibilities.
171 SmallVector<AffineExpr> inputExprs;
172 inputExprs.append(llvm::map_to_vector(
173 llvm::seq<int64_t>(map.getNumDims()),
174 [&](int64_t dim) { return getAffineDimExpr(dim, map.getContext()); }));
175 inputExprs.append(llvm::map_to_vector(
176 llvm::seq<int64_t>(map.getNumSymbols()),
177 [&](int64_t sym) { return getAffineSymbolExpr(sym, map.getContext()); }));
178 for (auto [expr, divisibility] :
179 llvm::zip_equal(inputExprs, dimAndSymbolDivisibilities)) {
180 exprDivisibilityMap[expr] = divisibility;
181 }
182 AffineExprDivisibilityFinder divisibilityFinder(exprDivisibilityMap);
183
184 // Walk each result expression and compute their divisibilities.
185 SmallVector<ConstantIntDivisibility> resultDivisibilities;
186 for (AffineExpr resultExpr : map.getResults())
187 resultDivisibilities.push_back(divisibilityFinder.visit(resultExpr));
188 return resultDivisibilities;
189}
190
191/// Infer the result divisibility of an affine.min or affine.max operation
192/// based on its operand divisibilities. The result divisibility is the GCD
193/// of the divisibilities of each of the affine map results, because the result
194/// of the affine.min/max op could be any of these results.
195template <typename MinOrMaxTy>
196void inferAffineMinOrMaxResultDivisibility(
197 MinOrMaxTy minOrMaxOp, ArrayRef<IntegerDivisibility> argDivs,
198 SetIntDivisibilityFn setResultDivs) {
199 static_assert(llvm::is_one_of<MinOrMaxTy, AffineMinOp, AffineMaxOp>::value,
200 "MinOrMaxTy must be AffineMinOp or AffineMaxOp");
201 SmallVector<ConstantIntDivisibility> operandDivisibilities;
202 for (auto [operand, divisibility] :
203 llvm::zip(minOrMaxOp.getOperands(), argDivs)) {
204 operandDivisibilities.push_back(
205 getDivisibilityOfOperand(operand, divisibility));
206 }
207
208 SmallVector<ConstantIntDivisibility> resultDivisibilities =
209 getResultDivisibilities(minOrMaxOp.getMap(), operandDivisibilities);
210
211 ConstantIntDivisibility resultDivisibility =
212 resultDivisibilities.pop_back_val();
213 for (auto divisibility : resultDivisibilities)
214 resultDivisibility = resultDivisibility.getUnion(divisibility);
215 setResultDivs(minOrMaxOp.getResult(), resultDivisibility);
216}
217
218} // namespace
219
220void AffineApplyOp::inferResultDivisibility(
222 SmallVector<ConstantIntDivisibility> operandDivisibilities;
223 for (auto [operand, divisibility] : llvm::zip(getOperands(), argDivs)) {
224 operandDivisibilities.push_back(
225 getDivisibilityOfOperand(operand, divisibility));
226 }
227
228 SmallVector<ConstantIntDivisibility> resultDivisibilities =
229 getResultDivisibilities(getMap(), operandDivisibilities);
230 for (auto [result, divisibility] :
231 llvm::zip_equal(getOperation()->getResults(), resultDivisibilities)) {
232 setResultDivs(result, divisibility);
233 }
234}
235
236void AffineMinOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
237 SetIntDivisibilityFn setResultDivs) {
238 inferAffineMinOrMaxResultDivisibility(*this, argDivs, setResultDivs);
239}
240
241void AffineMaxOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
242 SetIntDivisibilityFn setResultDivs) {
243 inferAffineMinOrMaxResultDivisibility(*this, argDivs, setResultDivs);
244}
245
246void AffineDelinearizeIndexOp::inferResultDivisibility(
248 MLIRContext *ctx = getContext();
249
250 // Operands are: [linear_index, dynamic_basis_values...]
251 ConstantIntDivisibility linearDiv =
252 getDivisibilityOfOperand(getLinearIndex(), argDivs[0]);
253
254 ArrayRef<int64_t> staticBasis = getStaticBasis();
255 int64_t numResults = getNumResults();
256
257 // Build affine expressions for each result.
258 // Dim 0 = linear index, symbols = dynamic basis values.
259 AffineExpr linearExpr = getAffineDimExpr(0, ctx);
260
261 // Collect operand divisibilities: [linear_index_div, dynamic_basis_divs...]
263 operandDivs.push_back(linearDiv);
264
265 // Map static/dynamic basis values to affine expressions.
266 int64_t dynIdx = 0;
267 SmallVector<AffineExpr> basisExprs;
268 for (int64_t i = 0, e = static_cast<int64_t>(staticBasis.size()); i < e;
269 ++i) {
270 if (ShapedType::isDynamic(staticBasis[i])) {
271 basisExprs.push_back(getAffineSymbolExpr(dynIdx, ctx));
272 operandDivs.push_back(getDivisibilityOfOperand(getDynamicBasis()[dynIdx],
273 argDivs[1 + dynIdx]));
274 dynIdx++;
275 } else {
276 basisExprs.push_back(getAffineConstantExpr(staticBasis[i], ctx));
277 }
278 }
279
280 // The computation basis skips the outer bound if present.
281 bool hasOuter = hasOuterBound();
282 int64_t basisStart = hasOuter ? 1 : 0;
283
284 // Each result[i] can be expressed as an affine expression of the linear
285 // index using the effective basis (after dropping outer bound if present).
286 // Effective basis B[k] = basisExprs[basisStart + k], for k = 0..N-2.
287 // Stride s[i] = product of B[i..N-2] = product of
288 // basisExprs[basisStart+i .. end].
289 //
290 // result[0] = x floordiv s[0]
291 // result[i>0] = (x floordiv s[i]) mod B[i-1]
292 // For i=N-1, s[N-1]=1, so result[N-1] = x mod B[N-2].
293
294 AffineExpr stride = getAffineConstantExpr(1, ctx);
295 for (int64_t i = numResults - 1; i >= 0; --i) {
296 AffineExpr resultExpr;
297 if (i == 0) {
298 resultExpr = linearExpr.floorDiv(stride);
299 } else {
300 resultExpr =
301 (linearExpr.floorDiv(stride)) % basisExprs[basisStart + i - 1];
302 }
303
304 AffineMap resultMap = AffineMap::get(1, dynIdx, resultExpr, ctx);
306 getResultDivisibilities(resultMap, operandDivs);
307 setResultDivs(getResult(i), divs[0]);
308
309 if (i > 0)
310 stride = basisExprs[basisStart + i - 1] * stride;
311 }
312}
static ConstantIntDivisibility getDivisibilityOfOperand(Value v, IntegerDivisibility divisibility)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition PDL.cpp:62
b getContext())
AffineExpr getLHS() const
AffineExpr getRHS() const
int64_t getValue() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
Statically known divisibility information for an integer SSA value.
ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const
This lattice value represents the integer divisibility of an SSA value.
const ConstantIntDivisibility & getValue() const
static IntegerDivisibility getMinDivisibility()
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
llvm::function_ref< void(Value, const ConstantIntDivisibility &)> SetIntDivisibilityFn
The type of the setResultDivs callback provided to ops implementing InferIntDivisibilityOpInterface.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)