MLIR  15.0.0git
Merger.h
Go to the documentation of this file.
1 //===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for dealing with iteration lattices.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
15 
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/BitVector.h"
19 
20 namespace mlir {
21 namespace sparse_tensor {
22 
23 /// Dimension level type for a tensor (undef means index does not appear).
25 
26 /// Tensor expression kind.
27 enum Kind {
28  // Leaf.
29  kTensor = 0, // leaf that holds a tensor number (TensorExp::tensor)
32  // Unary operations.
49  kCastFS, // signed
50  kCastFU, // unsigned
51  kCastSF, // signed
52  kCastUF, // unsigned
53  kCastS, // signed
54  kCastU, // unsigned
57  kCIm, // complex.im
58  kCRe, // complex.re
60  kBinaryBranch, // semiring unary branch created from a binary op; TensorExp::op holds the left or right YieldOp
61  kUnary, // semiring unary op
62  // Binary operations.
67  kDivC, // complex
68  kDivS, // signed
69  kDivU, // unsigned
79  kShrS, // signed
80  kShrU, // unsigned
82  kBinary, // semiring binary op
83 };
84 
85 /// Child subexpressions of tensor operations, stored as expression indices.
86 struct Children {
87  unsigned e0; // index of the first child expression
88  unsigned e1; // index of the second child; presumably -1u when absent (cf. Merger::addExp defaults) -- TODO confirm
89 };
90 
91 /// Tensor expression. Represents an MLIR expression in tensor index notation.
92 struct TensorExp {
93  TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation);
94 
95  /// Tensor expression kind.
97 
98  union {
99  /// Expressions representing tensors simply have a tensor number.
100  unsigned tensor;
101 
102  /// Indices hold the index number.
103  unsigned index;
104 
105  /// Tensor operations hold the indices of their children.
107  };
108 
109  /// Direct link to IR for an invariant or the destination value (to
110  /// infer destination type) of a cast operation. During code generation,
111  /// this field may be used to cache "hoisted" loop invariant tensor loads.
113 
114  /// Code blocks used by semirings. For the case of kUnary and
115  /// kBinary, this holds the original operation with all regions. For
116  /// kBinaryBranch, this holds the YieldOp for the left or right half
117  /// to be merged into a nested scf loop.
119 };
120 
121 /// Lattice point. Each lattice point consists of a conjunction of tensor
122 /// loop indices (encoded in a bitvector) and the index of the corresponding
123 /// tensor expression.
124 struct LatPoint {
125  LatPoint(unsigned n, unsigned e, unsigned b); // NOTE(review): presumably n bits, expression e, with bit b set -- confirm in Merger.cpp
126  LatPoint(const BitVector &b, unsigned e); // presumably bits taken from b, expression e -- confirm in Merger.cpp
127 
128  /// Conjunction of tensor loop indices as bitvector. This represents
129  /// all indices involved in the tensor expression.
130  BitVector bits;
131 
132  /// Simplified conjunction of tensor loop indices as bitvector. This
133  /// represents a simplified condition under which this tensor expression
134  /// must execute. Pre-computed during codegen to avoid repeated eval.
135  BitVector simple;
136 
137  /// Index of the tensor expression.
138  unsigned exp;
139 };
140 
141 /// A class to handle all iteration lattice operations. This class abstracts
142 /// away from some implementation details of storing iteration lattices and
143 /// tensor expressions. This allows for fine-tuning performance characteristics
144 /// independently from the basic algorithm if bottlenecks are identified.
145 class Merger {
146 public:
147  /// Constructs a merger for the given number of tensors and loops. The
148  /// user supplies the number of tensors involved in the kernel, with the
149  /// last tensor in this set denoting the output tensor. The merger adds an
150  /// additional synthetic tensor at the end of this set to represent all
151  /// invariant expressions in the kernel. All dims start out as Dim::kUndef.
152  Merger(unsigned t, unsigned l)
153  : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
154  hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
155 
156  /// Adds a tensor expression. Returns its index.
157  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
158  Operation *op = nullptr);
159  unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) { // single-child form: e1 = -1u
160  return addExp(k, e, -1u, v, op);
161  }
162  unsigned addExp(Kind k, Value v, Operation *op = nullptr) { // no-child form: e0 = e1 = -1u
163  return addExp(k, -1u, -1u, v, op);
164  }
165 
166  /// Adds an iteration lattice point. Returns its index.
167  unsigned addLat(unsigned t, unsigned i, unsigned e);
168 
169  /// Adds a new, initially empty, set. Returns its index.
170  unsigned addSet();
171 
172  /// Computes a single conjunction of two lattice points by taking the "union"
173  /// of loop indices (effectively constructing a larger "intersection" of those
174  /// indices) with a newly constructed tensor (sub)expression of given kind.
175  /// Returns the index of the new lattice point.
176  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1,
177  Operation *op = nullptr);
178 
179  /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
180  /// cartesian product. Returns the index of the new set.
181  unsigned takeConj(Kind kind, unsigned s0, unsigned s1,
182  Operation *op = nullptr);
183 
184  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
185  /// Returns the index of the new set.
186  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1,
187  Operation *op = nullptr);
188 
189  /// Disjunctive merge of two lattice sets L0 and L1 with custom handling of
190  /// the overlap, left, and right regions. Any region may be left missing in
191  /// the output. Returns the index of the new set.
192  unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
193  bool includeLeft, Kind ltrans, Operation *opleft,
194  bool includeRight, Kind rtrans, Operation *opright);
195 
196  /// Maps the unary operator over the lattice set of the operand, i.e. each
197  /// lattice point on an expression E is simply copied over, but with OP E
198  /// as new expression. Returns the index of the new set.
199  unsigned mapSet(Kind kind, unsigned s0, Value v = Value(),
200  Operation *op = nullptr);
201 
202  /// Optimizes the iteration lattice points in the given set. This
203  /// method should be called right before code generation to avoid
204  /// generating redundant loops and conditions.
205  unsigned optimizeSet(unsigned s0);
206 
207  /// Simplifies the conditions in a conjunction of a given lattice point
208  /// within the given set using just two basic rules:
209  /// (1) multiple dense conditions are reduced to single dense, and
210  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
211  BitVector simplifyCond(unsigned s0, unsigned p0);
212 
213  /// Returns true if Li > Lj.
214  bool latGT(unsigned i, unsigned j) const;
215 
216  /// Returns true if Li and Lj only differ in dense.
217  bool onlyDenseDiff(unsigned i, unsigned j);
218 
219  /// Bit translation: bit b encodes tensor b % numTensors at loop index b / numTensors.
220  unsigned tensor(unsigned b) const { return b % numTensors; }
221  unsigned index(unsigned b) const { return b / numTensors; }
222 
223  /// Returns true if bit corresponds to queried dim.
224  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
225 
226  /// Returns true if bit b encodes the output tensor at loop index i.
227  bool isOutTensor(unsigned b, unsigned i) const {
228  return tensor(b) == outTensor && index(b) == i;
229  }
230 
231  /// Returns true if tensor access at given index has queried dim.
232  bool isDim(unsigned t, unsigned i, Dim d) const {
233  assert(t < numTensors && i < numLoops);
234  return dims[t][i] == d;
235  }
236 
237  /// Returns true if any set bit corresponds to queried dim.
238  bool hasAnyDimOf(const BitVector &bits, Dim d) const;
239 
240  /// Returns true if given tensor iterates *only* in the given tensor
241  /// expression. For the output tensor, this defines a "simply dynamic"
242  /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
243  /// sparse vector a.
244  bool isSingleCondition(unsigned t, unsigned e) const;
245 
246  /// Dimension setter.
247  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } // NOTE(review): no bounds assert here, unlike isDim
248 
249  /// Has sparse output tensor setter.
250  void setHasSparseOut(bool s) { hasSparseOut = s; }
251 
252  /// Convenience getters to immediately access the stored nodes.
253  /// Typically it is inadvisable to keep the reference around, as in
254  /// "TensorExp &te = merger.exp(e)", since insertions into the merger
255  /// may cause data movement and invalidate the underlying memory address.
256  TensorExp &exp(unsigned e) { return tensorExps[e]; }
257  LatPoint &lat(unsigned l) { return latPoints[l]; }
258  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
259 
260 #ifndef NDEBUG
261  /// Print methods (for debugging).
262  void dumpExp(unsigned e) const;
263  void dumpLat(unsigned p) const;
264  void dumpSet(unsigned s) const;
265  void dumpBits(const BitVector &bits) const;
266 #endif
267 
268  /// Builds the iteration lattices in a bottom-up traversal given the remaining
269  /// tensor (sub)expression and the next loop index in the iteration graph.
270  /// Returns index of the root expression.
271  unsigned buildLattices(unsigned e, unsigned i);
272 
273  /// Builds a tensor expression from the given Linalg operation.
274  /// Returns index of the root expression on success.
275  Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
276 
277  /// Rebuilds SSA format from a tensor expression.
278  Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0,
279  Value v1);
280 
281 private:
282  /// Private helpers.
283  bool maybeZero(unsigned e) const;
284  bool isInvariant(unsigned e) const;
285  Type inferType(unsigned e, Value src);
286 
287  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
288  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
289 
290  /// Merger data structures.
291  const unsigned outTensor;
292  const unsigned syntheticTensor;
293  const unsigned numTensors;
294  const unsigned numLoops;
295  bool hasSparseOut;
296  std::vector<std::vector<Dim>> dims; // dims[t][i]: level type of tensor t at loop index i
300 };
301 
302 } // namespace sparse_tensor
303 } // namespace mlir
304 
305 #endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
Kind
Tensor expression kind.
Definition: Merger.h:27
Include the generated interface declarations.
bool isOutTensor(unsigned b, unsigned i) const
Returns true if bit corresponds to index of output tensor.
Definition: Merger.h:227
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
unsigned tensor(unsigned b) const
Bit translation.
Definition: Merger.h:220
bool isDim(unsigned b, Dim d) const
Returns true if bit corresponds to queried dim.
Definition: Merger.h:224
unsigned addExp(Kind k, unsigned e, Value v, Operation *op=nullptr)
Definition: Merger.h:159
TensorExp & exp(unsigned e)
Convenience getters to immediately access the stored nodes.
Definition: Merger.h:256
BitVector bits
Conjunction of tensor loop indices as bitvector.
Definition: Merger.h:130
unsigned exp
Index of the tensor expression.
Definition: Merger.h:138
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:48
Operation * op
Code blocks used by semirings.
Definition: Merger.h:118
Children subexpressions of tensor operations.
Definition: Merger.h:86
unsigned tensor
Expressions representing tensors simply have a tensor number.
Definition: Merger.h:100
Tensor expression. Represents a MLIR expression in tensor index notation.
Definition: Merger.h:92
Kind kind
Tensor expression kind.
Definition: Merger.h:96
LatPoint & lat(unsigned l)
Definition: Merger.h:257
Eliminates identifier at the specified position using Fourier-Motzkin variable elimination.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:72
void setHasSparseOut(bool s)
Definition: Merger.h:250
unsigned index(unsigned b) const
Definition: Merger.h:221
BitVector simple
Simplified conjunction of tensor loop indices as bitvector.
Definition: Merger.h:135
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
unsigned index
Indices hold the index number.
Definition: Merger.h:103
unsigned addExp(Kind k, Value v, Operation *op=nullptr)
Definition: Merger.h:162
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast operation. During code generation, this field may be used to cache "hoisted" loop invariant tensor loads.
Definition: Merger.h:112
void setDim(unsigned t, unsigned i, Dim d)
Dimension setter.
Definition: Merger.h:247
bool isDim(unsigned t, unsigned i, Dim d) const
Returns true if tensor access at given index has queried dim.
Definition: Merger.h:232
Merger(unsigned t, unsigned l)
Constructs a merger for the given number of tensors and loops.
Definition: Merger.h:152
A class to handle all iteration lattice operations.
Definition: Merger.h:145
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:398
Children children
Tensor operations hold the indices of their children.
Definition: Merger.h:106