MLIR  17.0.0git
AffineExprVisitor.h
Go to the documentation of this file.
1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the AffineExpr visitor class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_AFFINEEXPRVISITOR_H
14 #define MLIR_IR_AFFINEEXPRVISITOR_H
15 
16 #include "mlir/IR/AffineExpr.h"
17 #include "llvm/ADT/ArrayRef.h"
18 
19 namespace mlir {
20 
21 /// Base class for AffineExpr visitors/walkers.
22 ///
23 /// AffineExpr visitors are used when you want to perform different actions
24 /// for different kinds of AffineExprs without having to use lots of casts
25 /// and a big switch statement.
26 ///
27 /// To define your own visitor, inherit from this class, specifying your
28 /// new type for the 'SubClass' template parameter, and "override" visitXXX
29 /// functions in your class. This class is defined in terms of statically
30 /// resolved overloading, not virtual functions.
31 ///
32 /// For example, here is a visitor that counts the number of AffineDimExprs
33 /// in an AffineExpr.
34 ///
35 /// /// Declare the class. Note that we derive from AffineExprVisitor
36 /// /// instantiated with our new subclass's type.
37 ///
38 /// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
39 /// unsigned numDimExprs;
40 /// DimExprCounter() : numDimExprs(0) {}
41 /// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
42 /// };
43 ///
44 /// And this class would be used like this:
45 /// DimExprCounter dec;
46 /// dec.visit(affineExpr);
47 /// numDimExprs = dec.numDimExprs;
48 ///
49 /// AffineExprVisitor provides visit methods for the following binary affine
50 /// op expressions:
51 /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
52 /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
53 /// AffineBinaryCeilDivOpExpr. Note that default implementations of these
54 /// methods will call the general AffineBinaryOpExpr method.
55 ///
56 /// In addition, visit methods are provided for the following affine
57 /// expressions: AffineConstantExpr, AffineDimExpr, and
58 /// AffineSymbolExpr.
59 ///
60 /// Note that if you don't implement visitXXX for some affine expression type,
61 /// the default visitXXX method defined in this class will be invoked.
62 ///
63 /// Note that this class is specifically designed as a template to avoid
64 /// virtual function call overhead. Defining and using an AffineExprVisitor is
65 /// just as efficient as having your own switch statement over the expression
66 /// kind.
67 
68 template <typename SubClass, typename RetTy = void>
70  //===--------------------------------------------------------------------===//
71  // Interface code - This is the public interface of the AffineExprVisitor
72  // that you use to visit affine expressions...
73 public:
74  // Function to walk an AffineExpr (in post order).
75  RetTy walkPostOrder(AffineExpr expr) {
76  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
77  "Must instantiate with a derived type of AffineExprVisitor");
78  switch (expr.getKind()) {
79  case AffineExprKind::Add: {
80  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
81  walkOperandsPostOrder(binOpExpr);
82  return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
83  }
84  case AffineExprKind::Mul: {
85  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
86  walkOperandsPostOrder(binOpExpr);
87  return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
88  }
89  case AffineExprKind::Mod: {
90  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
91  walkOperandsPostOrder(binOpExpr);
92  return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
93  }
95  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
96  walkOperandsPostOrder(binOpExpr);
97  return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
98  }
100  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
101  walkOperandsPostOrder(binOpExpr);
102  return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
103  }
105  return static_cast<SubClass *>(this)->visitConstantExpr(
106  expr.cast<AffineConstantExpr>());
108  return static_cast<SubClass *>(this)->visitDimExpr(
109  expr.cast<AffineDimExpr>());
111  return static_cast<SubClass *>(this)->visitSymbolExpr(
112  expr.cast<AffineSymbolExpr>());
113  }
114  }
115 
116  // Function to visit an AffineExpr.
117  RetTy visit(AffineExpr expr) {
118  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
119  "Must instantiate with a derived type of AffineExprVisitor");
120  switch (expr.getKind()) {
121  case AffineExprKind::Add: {
122  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
123  return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
124  }
125  case AffineExprKind::Mul: {
126  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
127  return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
128  }
129  case AffineExprKind::Mod: {
130  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
131  return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
132  }
134  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
135  return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
136  }
138  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
139  return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
140  }
142  return static_cast<SubClass *>(this)->visitConstantExpr(
143  expr.cast<AffineConstantExpr>());
145  return static_cast<SubClass *>(this)->visitDimExpr(
146  expr.cast<AffineDimExpr>());
148  return static_cast<SubClass *>(this)->visitSymbolExpr(
149  expr.cast<AffineSymbolExpr>());
150  }
151  llvm_unreachable("Unknown AffineExpr");
152  }
153 
154  //===--------------------------------------------------------------------===//
155  // Visitation functions... these functions provide default fallbacks in case
156  // the user does not specify what to do for a particular expression type.
157  // The default behavior is to generalize the expression type to its supertype
158  // and try visiting the supertype. All of this should be inlined perfectly,
159  // because there are no virtual functions to get in the way.
160  //
161 
162  // Default visit methods. Note that the default op-specific binary op visit
163  // methods call the general visitAffineBinaryOpExpr visit method.
164  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
166  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
167  }
169  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
170  }
172  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
173  }
175  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
176  }
178  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
179  }
180  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
181  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
182  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
183 
184 private:
185  // Walk both operands of `expr` in post order (LHS, then RHS). Returns void: operand results are discarded, and a RetTy return with no return statement is UB for non-void RetTy.
186  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
187  walkPostOrder(expr.getLHS());
188  walkPostOrder(expr.getRHS());
189  }
190 };
191 
192 // This class is used to flatten a pure affine expression (AffineExpr,
193 // which is in a tree form) into a sum of products (w.r.t constants) when
194 // possible, and in that process simplifying the expression. For a modulo,
195 // floordiv, or a ceildiv expression, an additional identifier, called a local
196 // identifier, is introduced to rewrite the expression as a sum of product
197 // affine expression. Each local identifier is always and by construction a
198 // floordiv of a pure add/mul affine function of dimensional, symbolic, and
199 // other local identifiers, in a non-mutually recursive way. Hence, every local
200 // identifier can ultimately always be recovered as an affine function of
201 // dimensional and symbolic identifiers (involving floordiv's); note however
202 // that by AffineExpr construction, some floordiv combinations are converted to
203 // mod's. The result of the flattening is a flattened expression and a set of
204 // constraints involving just the local variables.
205 //
206 // d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
207 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
208 //
209 // The simplification performed includes the accumulation of contributions for
210 // each dimensional and symbolic identifier together, the simplification of
211 // floordiv/ceildiv/mod expressions and other simplifications that in turn
212 // happen as a result. A simplification that this flattening naturally performs
213 // is of simplifying the numerator and denominator of floordiv/ceildiv, and
214 // folding a modulo expression to a zero, if possible. Three examples are below:
215 //
216 // (((d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
217 // (d0 - d0 mod 4 + 4) mod 4 simplified to 0
218 // (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
219 //
220 // The way the flattening works for the second example is as follows: d0 % 4 is
221 // replaced by d0 - 4*q with q being introduced: the expression then simplifies
222 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
223 // zero. Note that an affine expression may not always be expressible purely as
224 // a sum of products involving just the original dimensional and symbolic
225 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that
226 // may not be eliminated after simplification; in such cases, the final
227 // expression can be reconstructed by replacing the local identifiers with their
228 // corresponding explicit form stored in 'localExprs' (note that each of the
229 // explicit forms itself would have been simplified).
230 //
231 // The expression walk method here performs a linear time post order walk that
232 // performs the above simplifications through visit methods, with partial
233 // results being stored in 'operandExprStack'. When a parent expr is visited,
234 // the flattened expressions corresponding to its two operands would already be
235 // on the stack - the parent expression looks at the two flattened expressions
236 // and combines the two. It pops off the operand expressions and pushes the
237 // combined result (although this is done in-place on its LHS operand expr).
238 // When the walk is completed, the flattened form of the top-level expression
239 // would be left on the stack.
240 //
241 // A flattener can be repeatedly used for multiple affine expressions that bind
242 // to the same operands, for example, for all result expressions of an
243 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions
244 // is more efficient than creating a new flattener for each expression since
245 // common identical div and mod expressions appearing across different
246 // expressions are mapped to the same local identifier (same column position in
247 // 'localVarCst').
249  : public AffineExprVisitor<SimpleAffineExprFlattener> {
250 public:
251  // Flattened expression layout: [dims, symbols, locals, constant]
252  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
253  // In future, consider adding a prepass to determine how big the SmallVector's
254  // will be, and linearize this to std::vector<int64_t> to prevent
255  // SmallVector moves on re-allocation.
256  std::vector<SmallVector<int64_t, 8>> operandExprStack;
257 
258  unsigned numDims;
259  unsigned numSymbols;
260 
261  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
262  unsigned numLocals;
263 
264  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
265  // which new identifiers were introduced; if the latter do not get canceled
266  // out, these expressions can be readily used to reconstruct the AffineExpr
267  // (tree) form. Note that these expressions themselves would have been
268  // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
269  // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
270  // ceildiv 2 would be the local expression stored for q.
272 
273  SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
274 
275  virtual ~SimpleAffineExprFlattener() = default;
276 
277  // Visitor method overrides.
278  void visitMulExpr(AffineBinaryOpExpr expr);
279  void visitAddExpr(AffineBinaryOpExpr expr);
280  void visitDimExpr(AffineDimExpr expr);
285 
286  //
287  // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
288  //
289  // A mod expression "expr mod c" is thus flattened by introducing a new local
290  // variable q (= expr floordiv c), such that expr mod c is replaced with
291  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
292  void visitModExpr(AffineBinaryOpExpr expr);
293 
294 protected:
295  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
296  // The local identifier added is always a floordiv of a pure add/mul affine
297  // function of other identifiers, coefficients of which are specified in
298  // dividend and with respect to a positive constant divisor. localExpr is the
299  // simplified tree expression (AffineExpr) corresponding to the quantifier.
300  virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
301  AffineExpr localExpr);
302 
303  /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
304  /// expr) when the rhs is a symbolic expression. The local identifier added
305  /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
306  /// function of other identifiers, coefficients of which are specified in the
307  /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
308  /// symbolic rhs expression. `localExpr` is the simplified tree expression
309  /// (AffineExpr) corresponding to the quantifier.
310  virtual void addLocalIdSemiAffine(AffineExpr localExpr);
311 
312 private:
313  /// Adds `expr`, which may be a mul, mod, ceildiv or floordiv expression
314  /// representing the affine expression corresponding to the quantifier
315  /// introduced as the local variable corresponding to `expr`. If the
316  /// quantifier is already present, we put the coefficient in the proper index
317  /// of `result`, otherwise we add a new local variable and put the coefficient
318  /// there.
319  void addLocalVariableSemiAffine(AffineExpr expr,
320  SmallVectorImpl<int64_t> &result,
321  unsigned long resultSize);
322 
323  // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
324  // A floordiv is thus flattened by introducing a new local variable q, and
325  // replacing that expression with 'q' while adding the constraints
326  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
327  // IntegerRelation::addLocalFloorDiv).
328  //
329  // A ceildiv is similarly flattened:
330  // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
331  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
332 
333  int findLocalId(AffineExpr localExpr);
334 
335  inline unsigned getNumCols() const {
336  return numDims + numSymbols + numLocals + 1;
337  }
338  inline unsigned getConstantIndex() const { return getNumCols() - 1; }
339  inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
340  inline unsigned getSymbolStartIndex() const { return numDims; }
341  inline unsigned getDimStartIndex() const { return 0; }
342 };
343 
344 } // namespace mlir
345 
346 #endif // MLIR_IR_AFFINEEXPRVISITOR_H
Affine binary operation expression.
Definition: AffineExpr.h:207
AffineExpr getLHS() const
Definition: AffineExpr.cpp:317
AffineExpr getRHS() const
Definition: AffineExpr.cpp:320
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base class for AffineExpr visitors/walkers.
RetTy visitMulExpr(AffineBinaryOpExpr expr)
RetTy visitAddExpr(AffineBinaryOpExpr expr)
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)
RetTy walkPostOrder(AffineExpr expr)
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr)
RetTy visitSymbolExpr(AffineSymbolExpr expr)
RetTy visit(AffineExpr expr)
RetTy visitConstantExpr(AffineConstantExpr expr)
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr)
RetTy visitModExpr(AffineBinaryOpExpr expr)
RetTy visitDimExpr(AffineDimExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
U cast() const
Definition: AffineExpr.h:291
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
void visitFloorDivExpr(AffineBinaryOpExpr expr)
void visitAddExpr(AffineBinaryOpExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
void visitDimExpr(AffineDimExpr expr)
void visitConstantExpr(AffineConstantExpr expr)
virtual ~SimpleAffineExprFlattener()=default
void visitSymbolExpr(AffineSymbolExpr expr)
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
SmallVector< AffineExpr, 4 > localExprs
void visitCeilDivExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
void visitMulExpr(AffineBinaryOpExpr expr)
This header declares functions that assist transformations in the MemRef dialect.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.