MLIR  18.0.0git
AffineExprVisitor.h
Go to the documentation of this file.
1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the AffineExpr visitor class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_AFFINEEXPRVISITOR_H
14 #define MLIR_IR_AFFINEEXPRVISITOR_H
15 
16 #include "mlir/IR/AffineExpr.h"
18 #include "llvm/ADT/ArrayRef.h"
19 
20 namespace mlir {
21 
22 /// Base class for AffineExpr visitors/walkers.
23 ///
24 /// AffineExpr visitors are used when you want to perform different actions
25 /// for different kinds of AffineExprs without having to use lots of casts
26 /// and a big switch statement.
27 ///
28 /// To define your own visitor, inherit from this class, specifying your
29 /// new type for the 'SubClass' template parameter, and "override" visitXXX
30 /// functions in your class. This class is defined in terms of statically
31 /// resolved overloading, not virtual functions.
32 ///
33 /// For example, here is a visitor that counts the number of AffineDimExprs
34 /// in an AffineExpr.
35 ///
36 /// /// Declare the class. Note that we derive from AffineExprVisitor
37 /// /// instantiated with our new subclass type.
38 ///
39 /// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
40 /// unsigned numDimExprs;
41 /// DimExprCounter() : numDimExprs(0) {}
42 /// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
43 /// };
44 ///
45 /// And this class would be used like this:
46 /// DimExprCounter dec;
47 /// dec.visit(affineExpr);
48 /// numDimExprs = dec.numDimExprs;
49 ///
50 /// AffineExprVisitor provides visit methods for the following binary affine
51 /// op expressions:
52 /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
53 /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
54 /// AffineBinaryCeilDivOpExpr. Note that default implementations of these
55 /// methods will call the general AffineBinaryOpExpr method.
56 ///
57 /// In addition, visit methods are provided for the following affine
58 // expressions: AffineConstantExpr, AffineDimExpr, and
59 // AffineSymbolExpr.
60 ///
61 /// Note that if you don't implement visitXXX for a binary op expression kind,
62 /// visitAffineBinaryOpExpr is invoked; the remaining defaults return RetTy().
63 ///
64 /// Note that this class is specifically designed as a template to avoid
65 /// virtual function call overhead. Defining and using an AffineExprVisitor is
66 /// just as efficient as having your own switch statement over the expression
67 /// kind.
68 
69 template <typename SubClass, typename RetTy>
71 public:
72  // Function to visit an AffineExpr.
73  RetTy visit(AffineExpr expr) {
74  static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
75  "Must instantiate with a derived type of AffineExprVisitor");
76  auto self = static_cast<SubClass *>(this);
77  switch (expr.getKind()) {
78  case AffineExprKind::Add: {
79  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
80  return self->visitAddExpr(binOpExpr);
81  }
82  case AffineExprKind::Mul: {
83  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
84  return self->visitMulExpr(binOpExpr);
85  }
86  case AffineExprKind::Mod: {
87  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
88  return self->visitModExpr(binOpExpr);
89  }
91  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
92  return self->visitFloorDivExpr(binOpExpr);
93  }
95  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
96  return self->visitCeilDivExpr(binOpExpr);
97  }
99  return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
101  return self->visitDimExpr(cast<AffineDimExpr>(expr));
103  return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
104  }
105  llvm_unreachable("Unknown AffineExpr");
106  }
107 
108  //===--------------------------------------------------------------------===//
109  // Visitation functions... these functions provide default fallbacks in case
110  // the user does not specify what to do for a particular expression kind.
111  // For binary op kinds, the default generalizes the expression to its
112  // supertype and visits that instead. All of this should be inlined perfectly,
113  // because there are no virtual functions to get in the way.
114  //
115 
116  // Default visit methods. Note that the default op-specific binary op visit
117  // methods call the general visitAffineBinaryOpExpr visit method.
118  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); } // Shared fallback of the binary-op visitors; ignores expr and returns RetTy().
120  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
121  }
123  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
124  }
126  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
127  }
129  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
130  }
132  return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
133  }
134  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); } // Default: ignores expr and returns RetTy().
135  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); } // Default: ignores expr and returns RetTy().
136  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); } // Default: ignores expr and returns RetTy().
137 };
138 
139 template <typename SubClass, typename RetTy = void>
140 class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
141  //===--------------------------------------------------------------------===//
142  // Interface code - This is the public interface of the AffineExprVisitor
143  // that you use to visit affine expressions...
144 public:
145  // Function to walk an AffineExpr (in post order).
146  RetTy walkPostOrder(AffineExpr expr) {
147  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
148  "Must instantiate with a derived type of AffineExprVisitor");
149  auto self = static_cast<SubClass *>(this);
150  switch (expr.getKind()) {
151  case AffineExprKind::Add: {
152  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
153  walkOperandsPostOrder(binOpExpr);
154  return self->visitAddExpr(binOpExpr);
155  }
156  case AffineExprKind::Mul: {
157  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
158  walkOperandsPostOrder(binOpExpr);
159  return self->visitMulExpr(binOpExpr);
160  }
161  case AffineExprKind::Mod: {
162  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
163  walkOperandsPostOrder(binOpExpr);
164  return self->visitModExpr(binOpExpr);
165  }
167  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
168  walkOperandsPostOrder(binOpExpr);
169  return self->visitFloorDivExpr(binOpExpr);
170  }
172  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
173  walkOperandsPostOrder(binOpExpr);
174  return self->visitCeilDivExpr(binOpExpr);
175  }
177  return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
179  return self->visitDimExpr(cast<AffineDimExpr>(expr));
181  return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
182  }
183  llvm_unreachable("Unknown AffineExpr");
184  }
185 
186 private:
187  // Walk each operand (LHS, then RHS) in post order. Returns void rather than RetTy: callers discard the result, and a non-void function without a return statement is UB.
188  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
189  walkPostOrder(expr.getLHS());
190  walkPostOrder(expr.getRHS());
191  }
192 };
193 
194 template <typename SubClass>
196  : public AffineExprVisitorBase<SubClass, LogicalResult> {
197  //===--------------------------------------------------------------------===//
198  // Interface code - This is the public interface of the AffineExprVisitor
199  // that you use to visit affine expressions...
200 public:
201  // Function to walk an AffineExpr (in post order).
203  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
204  "Must instantiate with a derived type of AffineExprVisitor");
205  auto self = static_cast<SubClass *>(this);
206  switch (expr.getKind()) {
207  case AffineExprKind::Add: {
208  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
209  if (failed(walkOperandsPostOrder(binOpExpr)))
210  return failure();
211  return self->visitAddExpr(binOpExpr);
212  }
213  case AffineExprKind::Mul: {
214  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
215  if (failed(walkOperandsPostOrder(binOpExpr)))
216  return failure();
217  return self->visitMulExpr(binOpExpr);
218  }
219  case AffineExprKind::Mod: {
220  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
221  if (failed(walkOperandsPostOrder(binOpExpr)))
222  return failure();
223  return self->visitModExpr(binOpExpr);
224  }
226  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
227  if (failed(walkOperandsPostOrder(binOpExpr)))
228  return failure();
229  return self->visitFloorDivExpr(binOpExpr);
230  }
232  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
233  if (failed(walkOperandsPostOrder(binOpExpr)))
234  return failure();
235  return self->visitCeilDivExpr(binOpExpr);
236  }
238  return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
240  return self->visitDimExpr(cast<AffineDimExpr>(expr));
242  return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
243  }
244  llvm_unreachable("Unknown AffineExpr");
245  }
246 
247 private:
248  // Walk LHS then RHS in post order, stopping at the first failure.
249  LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
250  // `&&` short-circuits: the RHS is walked only if the LHS walk succeeded,
251  // and the combined outcome is converted back into a LogicalResult.
252  bool lhsAndRhsOk = !failed(walkPostOrder(expr.getLHS())) &&
253  !failed(walkPostOrder(expr.getRHS()));
254  return success(lhsAndRhsOk);
255  }
256 };
257 
258 // This class is used to flatten a pure affine expression (AffineExpr,
259 // which is in a tree form) into a sum of products (w.r.t constants) when
260 // possible, and in that process simplifying the expression. For a modulo,
261 // floordiv, or a ceildiv expression, an additional identifier, called a local
262 // identifier, is introduced to rewrite the expression as a sum-of-products
263 // affine expression. Each local identifier is always and by construction a
264 // floordiv of a pure add/mul affine function of dimensional, symbolic, and
265 // other local identifiers, in a non-mutually recursive way. Hence, every local
266 // identifier can ultimately always be recovered as an affine function of
267 // dimensional and symbolic identifiers (involving floordiv's); note however
268 // that by AffineExpr construction, some floordiv combinations are converted to
269 // mod's. The result of the flattening is a flattened expression and a set of
270 // constraints involving just the local variables.
271 //
272 // d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
273 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
274 //
275 // The simplification performed includes the accumulation of contributions for
276 // each dimensional and symbolic identifier together, the simplification of
277 // floordiv/ceildiv/mod expressions and other simplifications that in turn
278 // happen as a result. A simplification that this flattening naturally performs
279 // is of simplifying the numerator and denominator of floordiv/ceildiv, and
280 // folding a modulo expression to a zero, if possible. Three examples are below:
281 //
282 // (((d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
283 // (d0 - d0 mod 4 + 4) mod 4 simplified to 0
284 // (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
285 //
286 // The way the flattening works for the second example is as follows: d0 % 4 is
287 // replaced by d0 - 4*q with q being introduced: the expression then simplifies
288 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
289 // zero. Note that an affine expression may not always be expressible purely as
290 // a sum of products involving just the original dimensional and symbolic
291 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that
292 // may not be eliminated after simplification; in such cases, the final
293 // expression can be reconstructed by replacing the local identifiers with their
294 // corresponding explicit form stored in 'localExprs' (note that each of the
295 // explicit forms itself would have been simplified).
296 //
297 // The expression walk method here performs a linear time post order walk that
298 // performs the above simplifications through visit methods, with partial
299 // results being stored in 'operandExprStack'. When a parent expr is visited,
300 // the flattened expressions corresponding to its two operands would already be
301 // on the stack - the parent expression looks at the two flattened expressions
302 // and combines the two. It pops off the operand expressions and pushes the
303 // combined result (although this is done in-place on its LHS operand expr).
304 // When the walk is completed, the flattened form of the top-level expression
305 // would be left on the stack.
306 //
307 // A flattener can be repeatedly used for multiple affine expressions that bind
308 // to the same operands, for example, for all result expressions of an
309 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions
310 // is more efficient than creating a new flattener for each expression since
311 // common identical div and mod expressions appearing across different
312 // expressions are mapped to the same local identifier (same column position in
313 // 'localVarCst').
315  : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
316 public:
317  // Flattened expression layout: [dims, symbols, locals, constant]
318  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
319  // In the future, consider adding a prepass to determine how big the SmallVectors
320  // will be, and linearize this to std::vector<int64_t> to prevent
321  // SmallVector moves on re-allocation.
322  std::vector<SmallVector<int64_t, 8>> operandExprStack;
323 
324  unsigned numDims;
325  unsigned numSymbols;
326 
327  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
328  unsigned numLocals;
329 
330  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
331  // which new identifiers were introduced; if the latter do not get canceled
332  // out, these expressions can be readily used to reconstruct the AffineExpr
333  // (tree) form. Note that these expressions themselves would have been
334  // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
335  // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
336  // ceildiv 2 would be the local expression stored for q.
338 
339  SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
340 
341  virtual ~SimpleAffineExprFlattener() = default;
342 
343  // Visitor method overrides.
351 
352  //
353  // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
354  //
355  // A mod expression "expr mod c" is thus flattened by introducing a new local
356  // variable q (= expr floordiv c), such that expr mod c is replaced with
357  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
359 
360 protected:
361  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
362  // The local identifier added is always a floordiv of a pure add/mul affine
363  // function of other identifiers, coefficients of which are specified in
364  // dividend and with respect to a positive constant divisor. localExpr is the
365  // simplified tree expression (AffineExpr) corresponding to the quantifier.
366  virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
367  AffineExpr localExpr);
368 
369  /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
370  /// expr) when the rhs is a symbolic expression. The local identifier added
371  /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
372  /// function of other identifiers, coefficients of which are specified in the
373  /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
374  /// symbolic rhs expression. `localExpr` is the simplified tree expression
375  /// (AffineExpr) corresponding to the quantifier.
376  virtual void addLocalIdSemiAffine(AffineExpr localExpr);
377 
378 private:
379  /// Adds `expr`, which may be a mul, ceildiv, floordiv or mod expression
380  /// representing the affine expression corresponding to the quantifier
381  /// introduced as the local variable corresponding to `expr`. If the
382  /// quantifier is already present, we put the coefficient in the proper index
383  /// of `result`, otherwise we add a new local variable and put the coefficient
384  /// there.
385  void addLocalVariableSemiAffine(AffineExpr expr,
386  SmallVectorImpl<int64_t> &result,
387  unsigned long resultSize);
388 
389  // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
390  // A floordiv is thus flattened by introducing a new local variable q, and
391  // replacing that expression with 'q' while adding the constraints
392  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
393  // IntegerRelation::addLocalFloorDiv).
394  //
395  // A ceildiv is similarly flattened:
396  // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
397  LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
398 
399  int findLocalId(AffineExpr localExpr);
400 
401  // Column layout of a flattened expression: [dims, symbols, locals, constant];
402  // each start index is derived from the one before it.
403  unsigned getDimStartIndex() const { return 0; }
404  unsigned getSymbolStartIndex() const { return getDimStartIndex() + numDims; }
405  unsigned getLocalVarStartIndex() const { return getSymbolStartIndex() + numSymbols; }
406  unsigned getConstantIndex() const { return getLocalVarStartIndex() + numLocals; }
407  unsigned getNumCols() const { return getConstantIndex() + 1; }
408 };
409 
410 } // namespace mlir
411 
412 #endif // MLIR_IR_AFFINEEXPRVISITOR_H
Affine binary operation expression.
Definition: AffineExpr.h:213
AffineExpr getLHS() const
Definition: AffineExpr.cpp:317
AffineExpr getRHS() const
Definition: AffineExpr.cpp:320
An integer constant appearing in affine expression.
Definition: AffineExpr.h:238
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:222
Base class for AffineExpr visitors/walkers.
RetTy visitAddExpr(AffineBinaryOpExpr expr)
RetTy visitDimExpr(AffineDimExpr expr)
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr)
RetTy visit(AffineExpr expr)
RetTy visitSymbolExpr(AffineSymbolExpr expr)
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)
RetTy visitModExpr(AffineBinaryOpExpr expr)
RetTy visitConstantExpr(AffineConstantExpr expr)
RetTy visitMulExpr(AffineBinaryOpExpr expr)
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr)
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:230
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
LogicalResult visitSymbolExpr(AffineSymbolExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
LogicalResult visitDimExpr(AffineDimExpr expr)
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitConstantExpr(AffineConstantExpr expr)
LogicalResult visitModExpr(AffineBinaryOpExpr expr)
LogicalResult visitAddExpr(AffineBinaryOpExpr expr)
virtual ~SimpleAffineExprFlattener()=default
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitMulExpr(AffineBinaryOpExpr expr)
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
SmallVector< AffineExpr, 4 > localExprs
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26