MLIR  20.0.0git
AffineExprVisitor.h
Go to the documentation of this file.
1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the AffineExpr visitor class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_AFFINEEXPRVISITOR_H
14 #define MLIR_IR_AFFINEEXPRVISITOR_H
15 
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/ArrayRef.h"
19 
20 namespace mlir {
21 
22 /// Base class for AffineExpr visitors/walkers.
23 ///
24 /// AffineExpr visitors are used when you want to perform different actions
25 /// for different kinds of AffineExprs without having to use lots of casts
26 /// and a big switch instruction.
27 ///
28 /// To define your own visitor, inherit from this class, specifying your
29 /// new type for the 'SubClass' template parameter, and "override" visitXXX
30 /// functions in your class. This class is defined in terms of statically
31 /// resolved overloading, not virtual functions.
32 ///
33 /// The visitor is templated on its return type (`RetTy`). With a WalkResult
34 /// return type, the visitor supports interrupting walks.
35 ///
36 /// For example, here is a visitor that counts the number of for AffineDimExprs
37 /// in an AffineExpr.
38 ///
39 /// /// Declare the class. Note that we derive from AffineExprVisitor
40 /// /// instantiated with our new subclasses_ type.
41 ///
42 /// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
43 /// unsigned numDimExprs;
44 /// DimExprCounter() : numDimExprs(0) {}
45 /// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
46 /// };
47 ///
48 /// And this class would be used like this:
49 /// DimExprCounter dec;
50 /// dec.visit(affineExpr);
51 /// numDimExprs = dec.numDimExprs;
52 ///
53 /// AffineExprVisitor provides visit methods for the following binary affine
54 /// op expressions:
55 /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
56 /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
57 /// AffineBinaryCeilDivOpExpr. Note that default implementations of these
58 /// methods will call the general AffineBinaryOpExpr method.
59 ///
60 /// In addition, visit methods are provided for the following affine
61 // expressions: AffineConstantExpr, AffineDimExpr, and
62 // AffineSymbolExpr.
63 ///
64 /// Note that if you don't implement visitXXX for some affine expression type,
65 /// the visitXXX method for Instruction superclass will be invoked.
66 ///
67 /// Note that this class is specifically designed as a template to avoid
68 /// virtual function call overhead. Defining and using a AffineExprVisitor is
69 /// just as efficient as having your own switch instruction over the instruction
70 /// opcode.
71 template <typename SubClass, typename RetTy>
73 public:
74  // Function to visit an AffineExpr.
75  RetTy visit(AffineExpr expr) {
76  static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
77  "Must instantiate with a derived type of AffineExprVisitor");
78  auto self = static_cast<SubClass *>(this);
79  switch (expr.getKind()) {
80  case AffineExprKind::Add: {
81  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
82  return self->visitAddExpr(binOpExpr);
83  }
84  case AffineExprKind::Mul: {
85  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
86  return self->visitMulExpr(binOpExpr);
87  }
88  case AffineExprKind::Mod: {
89  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
90  return self->visitModExpr(binOpExpr);
91  }
93  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
94  return self->visitFloorDivExpr(binOpExpr);
95  }
97  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
98  return self->visitCeilDivExpr(binOpExpr);
99  }
101  return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
103  return self->visitDimExpr(cast<AffineDimExpr>(expr));
105  return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
106  }
107  llvm_unreachable("Unknown AffineExpr");
108  }
109 
110  //===--------------------------------------------------------------------===//
111  // Visitation functions... these functions provide default fallbacks in case
112  // the user does not specify what to do for a particular instruction type.
113  // The default behavior is to generalize the instruction type to its subtype
114  // and try visiting the subtype. All of this should be inlined perfectly,
115  // because there are no virtual functions to get in the way.
116  //
117 
118  // Default visit methods. Note that the default op-specific binary op visit
119  // methods call the general visitAffineBinaryOpExpr visit method.
120  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
122  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
123  }
125  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
126  }
128  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
129  }
131  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
132  }
134  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
135  }
136  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
137  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
138  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
139 };
140 
141 /// See documentation for AffineExprVisitorBase. This visitor supports
142 /// interrupting walks when a `WalkResult` is used for `RetTy`.
143 template <typename SubClass, typename RetTy = void>
144 class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
145  //===--------------------------------------------------------------------===//
146  // Interface code - This is the public interface of the AffineExprVisitor
147  // that you use to visit affine expressions...
148 public:
149  // Function to walk an AffineExpr (in post order).
150  RetTy walkPostOrder(AffineExpr expr) {
151  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
152  "Must instantiate with a derived type of AffineExprVisitor");
153  auto self = static_cast<SubClass *>(this);
154  switch (expr.getKind()) {
155  case AffineExprKind::Add: {
156  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
157  if constexpr (std::is_same<RetTy, WalkResult>::value) {
158  if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
159  return WalkResult::interrupt();
160  } else {
161  walkOperandsPostOrder(binOpExpr);
162  }
163  return self->visitAddExpr(binOpExpr);
164  }
165  case AffineExprKind::Mul: {
166  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
167  if constexpr (std::is_same<RetTy, WalkResult>::value) {
168  if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
169  return WalkResult::interrupt();
170  } else {
171  walkOperandsPostOrder(binOpExpr);
172  }
173  return self->visitMulExpr(binOpExpr);
174  }
175  case AffineExprKind::Mod: {
176  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
177  if constexpr (std::is_same<RetTy, WalkResult>::value) {
178  if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
179  return WalkResult::interrupt();
180  } else {
181  walkOperandsPostOrder(binOpExpr);
182  }
183  return self->visitModExpr(binOpExpr);
184  }
186  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
187  if constexpr (std::is_same<RetTy, WalkResult>::value) {
188  if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
189  return WalkResult::interrupt();
190  } else {
191  walkOperandsPostOrder(binOpExpr);
192  }
193  return self->visitFloorDivExpr(binOpExpr);
194  }
196  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
197  if constexpr (std::is_same<RetTy, WalkResult>::value) {
198  if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
199  return WalkResult::interrupt();
200  } else {
201  walkOperandsPostOrder(binOpExpr);
202  }
203  return self->visitCeilDivExpr(binOpExpr);
204  }
206  return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
208  return self->visitDimExpr(cast<AffineDimExpr>(expr));
210  return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
211  }
212  llvm_unreachable("Unknown AffineExpr");
213  }
214 
215 private:
216  // Walk the operands - each operand is itself walked in post order.
217  RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
218  if constexpr (std::is_same<RetTy, WalkResult>::value) {
219  if (walkPostOrder(expr.getLHS()).wasInterrupted())
220  return WalkResult::interrupt();
221  } else {
222  walkPostOrder(expr.getLHS());
223  }
224  if constexpr (std::is_same<RetTy, WalkResult>::value) {
225  if (walkPostOrder(expr.getRHS()).wasInterrupted())
226  return WalkResult::interrupt();
227  return WalkResult::advance();
228  } else {
229  return walkPostOrder(expr.getRHS());
230  }
231  }
232 };
233 
234 template <typename SubClass>
235 class AffineExprVisitor<SubClass, LogicalResult>
236  : public AffineExprVisitorBase<SubClass, LogicalResult> {
237  //===--------------------------------------------------------------------===//
238  // Interface code - This is the public interface of the AffineExprVisitor
239  // that you use to visit affine expressions...
240 public:
241  // Function to walk an AffineExpr (in post order).
242  LogicalResult walkPostOrder(AffineExpr expr) {
243  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
244  "Must instantiate with a derived type of AffineExprVisitor");
245  auto self = static_cast<SubClass *>(this);
246  switch (expr.getKind()) {
247  case AffineExprKind::Add: {
248  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
249  if (failed(walkOperandsPostOrder(binOpExpr)))
250  return failure();
251  return self->visitAddExpr(binOpExpr);
252  }
253  case AffineExprKind::Mul: {
254  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
255  if (failed(walkOperandsPostOrder(binOpExpr)))
256  return failure();
257  return self->visitMulExpr(binOpExpr);
258  }
259  case AffineExprKind::Mod: {
260  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
261  if (failed(walkOperandsPostOrder(binOpExpr)))
262  return failure();
263  return self->visitModExpr(binOpExpr);
264  }
266  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
267  if (failed(walkOperandsPostOrder(binOpExpr)))
268  return failure();
269  return self->visitFloorDivExpr(binOpExpr);
270  }
272  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
273  if (failed(walkOperandsPostOrder(binOpExpr)))
274  return failure();
275  return self->visitCeilDivExpr(binOpExpr);
276  }
278  return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
280  return self->visitDimExpr(cast<AffineDimExpr>(expr));
282  return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
283  }
284  llvm_unreachable("Unknown AffineExpr");
285  }
286 
287 private:
288  // Walk the operands - each operand is itself walked in post order.
289  LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
290  if (failed(walkPostOrder(expr.getLHS())))
291  return failure();
292  if (failed(walkPostOrder(expr.getRHS())))
293  return failure();
294  return success();
295  }
296 };
297 
298 // This class is used to flatten a pure affine expression (AffineExpr,
299 // which is in a tree form) into a sum of products (w.r.t constants) when
300 // possible, and in that process simplifying the expression. For a modulo,
301 // floordiv, or a ceildiv expression, an additional identifier, called a local
302 // identifier, is introduced to rewrite the expression as a sum of product
303 // affine expression. Each local identifier is always and by construction a
304 // floordiv of a pure add/mul affine function of dimensional, symbolic, and
305 // other local identifiers, in a non-mutually recursive way. Hence, every local
306 // identifier can ultimately always be recovered as an affine function of
307 // dimensional and symbolic identifiers (involving floordiv's); note however
308 // that by AffineExpr construction, some floordiv combinations are converted to
309 // mod's. The result of the flattening is a flattened expression and a set of
310 // constraints involving just the local variables.
311 //
312 // d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
313 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
314 //
315 // The simplification performed includes the accumulation of contributions for
316 // each dimensional and symbolic identifier together, the simplification of
317 // floordiv/ceildiv/mod expressions and other simplifications that in turn
318 // happen as a result. A simplification that this flattening naturally performs
319 // is of simplifying the numerator and denominator of floordiv/ceildiv, and
320 // folding a modulo expression to a zero, if possible. Three examples are below:
321 //
322 // (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
323 // (d0 - d0 mod 4 + 4) mod 4 simplified to 0
324 // (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
325 //
326 // The way the flattening works for the second example is as follows: d0 % 4 is
327 // replaced by d0 - 4*q with q being introduced: the expression then simplifies
328 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
329 // zero. Note that an affine expression may not always be expressible purely as
330 // a sum of products involving just the original dimensional and symbolic
331 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that
332 // may not be eliminated after simplification; in such cases, the final
333 // expression can be reconstructed by replacing the local identifiers with their
334 // corresponding explicit form stored in 'localExprs' (note that each of the
335 // explicit forms itself would have been simplified).
336 //
337 // The expression walk method here performs a linear time post order walk that
338 // performs the above simplifications through visit methods, with partial
339 // results being stored in 'operandExprStack'. When a parent expr is visited,
340 // the flattened expressions corresponding to its two operands would already be
341 // on the stack - the parent expression looks at the two flattened expressions
342 // and combines the two. It pops off the operand expressions and pushes the
343 // combined result (although this is done in-place on its LHS operand expr).
344 // When the walk is completed, the flattened form of the top-level expression
345 // would be left on the stack.
346 //
347 // A flattener can be repeatedly used for multiple affine expressions that bind
348 // to the same operands, for example, for all result expressions of an
349 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions
350 // is more efficient than creating a new flattener for each expression since
351 // common identical div and mod expressions appearing across different
352 // expressions are mapped to the same local identifier (same column position in
353 // 'localVarCst').
355  : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
356 public:
357  // Flattend expression layout: [dims, symbols, locals, constant]
358  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
359  // In future, consider adding a prepass to determine how big the SmallVector's
360  // will be, and linearize this to std::vector<int64_t> to prevent
361  // SmallVector moves on re-allocation.
362  std::vector<SmallVector<int64_t, 8>> operandExprStack;
363 
364  unsigned numDims;
365  unsigned numSymbols;
366 
367  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
368  unsigned numLocals;
369 
370  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
371  // which new identifiers were introduced; if the latter do not get canceled
372  // out, these expressions can be readily used to reconstruct the AffineExpr
373  // (tree) form. Note that these expressions themselves would have been
374  // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
375  // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
376  // ceildiv 2 would be the local expression stored for q.
378 
379  SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
380 
381  virtual ~SimpleAffineExprFlattener() = default;
382 
383  // Visitor method overrides.
384  LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
385  LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
386  LogicalResult visitDimExpr(AffineDimExpr expr);
387  LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
388  LogicalResult visitConstantExpr(AffineConstantExpr expr);
389  LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
390  LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
391 
392  //
393  // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
394  //
395  // A mod expression "expr mod c" is thus flattened by introducing a new local
396  // variable q (= expr floordiv c), such that expr mod c is replaced with
397  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
398  LogicalResult visitModExpr(AffineBinaryOpExpr expr);
399 
400 protected:
401  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
402  // The local identifier added is always a floordiv of a pure add/mul affine
403  // function of other identifiers, coefficients of which are specified in
404  // dividend and with respect to a positive constant divisor. localExpr is the
405  // simplified tree expression (AffineExpr) corresponding to the quantifier.
406  virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
407  AffineExpr localExpr);
408 
409  /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
410  /// expr) when the rhs is a symbolic expression. The local identifier added
411  /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
412  /// function of other identifiers, coefficients of which are specified in the
413  /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
414  /// symbolic rhs expression. `localExpr` is the simplified tree expression
415  /// (AffineExpr) corresponding to the quantifier.
416  virtual LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
417  ArrayRef<int64_t> rhs,
418  AffineExpr localExpr);
419 
420 private:
421  /// Adds `localExpr`, which may be mod, ceildiv, floordiv or mod expression
422  /// representing the affine expression corresponding to the quantifier
423  /// introduced as the local variable corresponding to `localExpr`. If the
424  /// quantifier is already present, we put the coefficient in the proper index
425  /// of `result`, otherwise we add a new local variable and put the coefficient
426  /// there.
427  LogicalResult addLocalVariableSemiAffine(ArrayRef<int64_t> lhs,
428  ArrayRef<int64_t> rhs,
429  AffineExpr localExpr,
430  SmallVectorImpl<int64_t> &result,
431  unsigned long resultSize);
432 
433  // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
434  // A floordiv is thus flattened by introducing a new local variable q, and
435  // replacing that expression with 'q' while adding the constraints
436  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
437  // IntegerRelation::addLocalFloorDiv).
438  //
439  // A ceildiv is similarly flattened:
440  // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
441  LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
442 
443  int findLocalId(AffineExpr localExpr);
444 
445  inline unsigned getNumCols() const {
446  return numDims + numSymbols + numLocals + 1;
447  }
448  inline unsigned getConstantIndex() const { return getNumCols() - 1; }
449  inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
450  inline unsigned getSymbolStartIndex() const { return numDims; }
451  inline unsigned getDimStartIndex() const { return 0; }
452 };
453 
454 } // namespace mlir
455 
456 #endif // MLIR_IR_AFFINEEXPRVISITOR_H
Affine binary operation expression.
Definition: AffineExpr.h:227
AffineExpr getLHS() const
Definition: AffineExpr.cpp:340
AffineExpr getRHS() const
Definition: AffineExpr.cpp:343
An integer constant appearing in affine expression.
Definition: AffineExpr.h:252
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
Base class for AffineExpr visitors/walkers.
RetTy visitAddExpr(AffineBinaryOpExpr expr)
RetTy visitDimExpr(AffineDimExpr expr)
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr)
RetTy visit(AffineExpr expr)
RetTy visitSymbolExpr(AffineSymbolExpr expr)
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)
RetTy visitModExpr(AffineBinaryOpExpr expr)
RetTy visitConstantExpr(AffineConstantExpr expr)
RetTy visitMulExpr(AffineBinaryOpExpr expr)
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr)
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:244
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
LogicalResult visitSymbolExpr(AffineSymbolExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
LogicalResult visitDimExpr(AffineDimExpr expr)
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitConstantExpr(AffineConstantExpr expr)
virtual LogicalResult addLocalIdSemiAffine(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbolic expression.
LogicalResult visitModExpr(AffineBinaryOpExpr expr)
LogicalResult visitAddExpr(AffineBinaryOpExpr expr)
virtual ~SimpleAffineExprFlattener()=default
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitMulExpr(AffineBinaryOpExpr expr)
SmallVector< AffineExpr, 4 > localExprs
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Include the generated interface declarations.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.