MLIR  16.0.0git
Merger.h
Go to the documentation of this file.
1 //===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for dealing with iteration lattices.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
15 
18 #include "mlir/IR/Value.h"
19 #include "llvm/ADT/BitVector.h"
20 
21 namespace mlir {
22 namespace sparse_tensor {
23 
24 /// Tensor expression kind.
/// NOTE(review): this dump is missing most enumerators — the embedded line
/// numbering jumps 27->30, 30->51, 56->59, 60->62, 65->70, 72->82 and 83->85,
/// so the leaf kinds after kTensor (e.g. invariant/index) and the arithmetic
/// unary/binary operations were dropped by the extraction. Restore them from
/// the upstream header before treating this as a complete definition.
25 enum Kind {
26  // Leaf.
27  kTensor = 0,
30  // Unary operations.
51  kCastFS, // signed
52  kCastFU, // unsigned
53  kCastSF, // signed
54  kCastUF, // unsigned
55  kCastS, // signed
56  kCastU, // unsigned
59  kCIm, // complex.im
60  kCRe, // complex.re
62  kBinaryBranch, // semiring unary branch created from a binary op
63  kUnary, // semiring unary op
64  kSelect, // custom selection criteria
65  // Binary operations.
70  kDivC, // complex
71  kDivS, // signed
72  kDivU, // unsigned
82  kShrS, // signed
83  kShrU, // unsigned
85  kBinary, // semiring binary op
86  kReduce, // semiring reduction op
87 };
88 
89 /// Children subexpressions of tensor operations.
90 struct Children {
  /// Index (into the merger's expression storage) of the first child
  /// subexpression.
91  unsigned e0;
  /// Index of the second child subexpression. For unary operations,
  /// presumably unused (-1u sentinel per Merger::addExp) — TODO confirm.
92  unsigned e1;
93 };
94 
95 /// Tensor expression. Represents a MLIR expression in tensor index notation.
96 struct TensorExp {
97  TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation);
98 
99  /// Tensor expression kind.
100  Kind kind;
101 
102  union {
103  /// Expressions representing tensors simply have a tensor number.
104  unsigned tensor;
105 
106  /// Indices hold the index number.
107  unsigned index;
108 
109  /// Tensor operations hold the indices of their children.
110  Children children;
111  };
112 
113  /// Direct link to IR for an invariant or the destination value (to
114  /// infer destination type) of a cast operation. During code generation,
115  /// this field may be used to cache "hoisted" loop invariant tensor loads.
116  Value val;
117 
118  /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
119  /// and kSelect, this holds the original operation with all regions. For
120  /// kBinaryBranch, this holds the YieldOp for the left or right half
121  /// to be merged into a nested scf loop.
122  Operation *op;
123 };
124 
125 /// Lattice point. Each lattice point consists of a conjunction of tensor
126 /// loop indices (encoded in a bitvector) and the index of the corresponding
127 /// tensor expression.
128 struct LatPoint {
  /// NOTE(review): presumably `n` is the bitvector width and `b` the single
  /// bit to set, with `e` the tensor expression index — confirm against the
  /// definition in Merger.cpp:126.
129  LatPoint(unsigned n, unsigned e, unsigned b);
  /// Constructs a lattice point from an existing bit conjunction `b` and
  /// tensor expression index `e`.
130  LatPoint(const BitVector &b, unsigned e);
131 
132  /// Conjunction of tensor loop indices as bitvector. This represents
133  /// all indices involved in the tensor expression.
134  BitVector bits;
135 
136  /// Simplified conjunction of tensor loop indices as bitvector. This
137  /// represents a simplified condition under which this tensor expression
138  /// must execute. Pre-computed during codegen to avoid repeated eval.
139  BitVector simple;
140 
141  /// Index of the tensor expression.
142  unsigned exp;
143 };
144 
145 /// A class to handle all iteration lattice operations. This class abstracts
146 /// away from some implementation details of storing iteration lattices and
147 /// tensor expressions. This allows for fine-tuning performance characteristics
148 /// independently from the basic algorithm if bottlenecks are identified.
149 class Merger {
150 public:
151  /// Constructs a merger for the given number of tensors, native loops, and
152  /// filter loops. The user supplies the number of tensors involved in the
153  /// kernel, with the last tensor in this set denoting the output tensor. The
154  /// merger adds an additional synthetic tensor at the end of this set to
155  /// represent all invariant expressions in the kernel.
156  /// In addition to natives
157  /// loops (which are specified by the GenericOp), extra filter loops are
158  /// needed in order to handle affine expressions on sparse dimensions.
159  /// E.g., (d0, d1, d2) => (d0 + d1, d2), a naive implementation of the filter
160  /// loop could be generated as:
161  /// for (coord : sparse_dim[0])
162  /// if (coord == d0 + d1) {
163  /// generated_code;
164  /// }
165  /// }
166  /// to filter out coordinates that are not equal to the affine expression
167  /// result.
168  /// TODO: we want to make the filter loop more efficient in the future, e.g.,
169  /// by avoiding scanning the full stored index sparse (keeping the last
170  /// position in ordered list) or even apply binary search to find the index.
171  Merger(unsigned t, unsigned l, unsigned fl)
172  : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1),
173  numNativeLoops(l), numLoops(l + fl), hasSparseOut(false),
174  dimTypes(numTensors,
175  std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
176  loopIdxToDim(numTensors,
177  std::vector<Optional<unsigned>>(numLoops, std::nullopt)),
178  dimToLoopIdx(numTensors,
179  std::vector<Optional<unsigned>>(numLoops, std::nullopt)) {}
180 
181  /// Adds a tensor expression. Returns its index.
182  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
183  Operation *op = nullptr);
184  unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) {
185  return addExp(k, e, -1u, v, op);
186  }
187  unsigned addExp(Kind k, Value v, Operation *op = nullptr) {
188  return addExp(k, -1u, -1u, v, op);
189  }
190 
191  /// Adds an iteration lattice point. Returns its index.
192  unsigned addLat(unsigned t, unsigned i, unsigned e);
193 
194  /// Adds a new, initially empty, set. Returns its index.
195  unsigned addSet();
196 
197  /// Computes a single conjunction of two lattice points by taking the "union"
198  /// of loop indices (effectively constructing a larger "intersection" of those
199  /// indices) with a newly constructed tensor (sub)expression of given kind.
200  /// Returns the index of the new lattice point.
201  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1,
202  Operation *op = nullptr);
203 
204  /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
205  /// cartesian product. Returns the index of the new set.
206  unsigned takeConj(Kind kind, unsigned s0, unsigned s1,
207  Operation *op = nullptr);
208 
209  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
210  /// Returns the index of the new set.
211  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1,
212  Operation *op = nullptr);
213 
214  /// Disjunctive merge of two lattice sets L0 and L1 with custom handling of
215  /// the overlap, left, and right regions. Any region may be left missing in
216  /// the output. Returns the index of the new set.
217  unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
218  bool includeLeft, Kind ltrans, Operation *opleft,
219  bool includeRight, Kind rtrans, Operation *opright);
220 
221  /// Maps the unary operator over the lattice set of the operand, i.e. each
222  /// lattice point on an expression E is simply copied over, but with OP E
223  /// as new expression. Returns the index of the new set.
224  unsigned mapSet(Kind kind, unsigned s0, Value v = Value(),
225  Operation *op = nullptr);
226 
227  /// Optimizes the iteration lattice points in the given set. This
228  /// method should be called right before code generation to avoid
229  /// generating redundant loops and conditions.
230  unsigned optimizeSet(unsigned s0);
231 
232  /// Simplifies the conditions in a conjunction of a given lattice point
233  /// within the given set using just two basic rules:
234  /// (1) multiple dense conditions are reduced to single dense, and
235  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
236  BitVector simplifyCond(unsigned s0, unsigned p0);
237 
238  /// Returns true if Li > Lj.
239  bool latGT(unsigned i, unsigned j) const;
240 
241  /// Returns true if Li and Lj only differ in dense.
242  bool onlyDenseDiff(unsigned i, unsigned j);
243 
244  /// Bit translation (get tensor ID).
245  unsigned tensor(unsigned b) const { return b % numTensors; }
246  /// Bit translation (get loop index).
247  unsigned index(unsigned b) const { return b / numTensors; }
248 
249  /// Get the number of total loops (native loops + filter loops).
250  unsigned getNumLoops() const { return numLoops; }
251  /// Get the number of native loops.
252  unsigned getNumNativeLoops() const { return numNativeLoops; }
253  /// Get the number of filter loops.
254  unsigned getNumFilterLoops() const { return numLoops - numNativeLoops; }
255  /// Get the starting filter loop index.
256  unsigned getFilterLoopStartingIdx() const { return getNumNativeLoops(); }
257 
258  /// Returns true if bit corresponds to index of output tensor.
259  bool isOutTensor(unsigned b, unsigned i) const {
260  return tensor(b) == outTensor && index(b) == i;
261  }
262 
263  /// Gets tensor ID for the output tensor.
264  unsigned getOutTensorID() const { return outTensor; }
265  /// Gets tensor ID for the synthetic tensor (used for all invariant tensor
266  /// expressions).
267  unsigned getSynTensorID() const { return syntheticTensor; }
268 
269  bool isFilterLoop(unsigned ldx) const {
270  assert(ldx < numLoops);
271  return ldx >= numNativeLoops;
272  }
273 
274  /// Returns true if given tensor iterates *only* in the given tensor
275  /// expression. For the output tensor, this defines a "simply dynamic"
276  /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
277  /// sparse vector a.
278  bool isSingleCondition(unsigned t, unsigned e) const;
279 
280  /// Returns true if any set bit corresponds to sparse dimension level type.
281  bool hasAnySparse(const BitVector &bits) const;
282 
283  /// Gets the dimension level type of the `t`th tensor on `i`th loop.
284  DimLevelType getDimLevelType(unsigned t, unsigned i) const {
285  assert(t < numTensors && i < numLoops);
286  return dimTypes[t][i];
287  }
288 
289  /// Gets the dimension level type of `b`.
290  DimLevelType getDimLevelType(unsigned b) const {
291  return getDimLevelType(tensor(b), index(b));
292  }
293 
294  Optional<unsigned> getLoopIdx(unsigned t, unsigned dim) const {
295  assert(t < numTensors && dim < numLoops);
296  return dimToLoopIdx[t][dim];
297  }
298 
299  /// Gets the dimension number of the the `t`th tensor on `i`th loop.
300  Optional<unsigned> getDimNum(unsigned t, unsigned i) const {
301  assert(t < numTensors && i < numLoops);
302  return loopIdxToDim[t][i];
303  }
304 
305  /// Gets the dimension number of `b`.
306  Optional<unsigned> getDimNum(unsigned b) const {
307  return getDimNum(tensor(b), index(b));
308  }
309 
310  /// Sets the dimension and dimension level type of the `t`th tensor on `i`th
311  /// loop.
312  void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim,
313  DimLevelType dlt) {
314  assert(isValidDLT(dlt));
315  dimTypes[t][i] = dlt;
316  loopIdxToDim[t][i] = dim;
317  assert(dim < numLoops);
318  dimToLoopIdx[t][dim] = i;
319  }
320 
321  // Iterates the bits of a lattice, for each set bit, converts it into the
322  // corresponding tensor dimension and invokes the callback.
324  const BitVector &bits,
325  function_ref<void(unsigned b, unsigned tid, Optional<unsigned> dim,
326  DimLevelType dlt)>
327  cb) {
328  for (unsigned b : bits.set_bits())
329  cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
330  }
331 
332  // Has sparse output tensor setter.
333  void setHasSparseOut(bool s) { hasSparseOut = s; }
334 
335  /// Convenience getters to immediately access the stored nodes.
336  /// Typically it is inadvisible to keep the reference around, as in
337  /// "TensorExpr &te = merger.exp(e))", since insertions into the merger
338  /// may cause data movement and invalidate the underlying memory address.
339  TensorExp &exp(unsigned e) { return tensorExps[e]; }
340  LatPoint &lat(unsigned l) { return latPoints[l]; }
341  SmallVector<unsigned> &set(unsigned s) { return latSets[s]; }
342 
343 #ifndef NDEBUG
344  /// Print methods (for debugging).
345  void dumpExp(unsigned e) const;
346  void dumpLat(unsigned p) const;
347  void dumpSet(unsigned s) const;
348  void dumpBits(const BitVector &bits) const;
349 #endif
350 
351  /// Builds the iteration lattices in a bottom-up traversal given the remaining
352  /// tensor (sub)expression and the next loop index in the iteration graph.
353  /// Returns index of the root expression.
354  unsigned buildLattices(unsigned e, unsigned i);
355 
356  /// Builds a tensor expression from the given Linalg operation.
357  /// Returns index of the root expression on success.
358  Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
359 
360  /// Rebuilds SSA format from a tensor expression.
361  Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0,
362  Value v1);
363 
364 private:
365  /// Private helpers.
366  bool maybeZero(unsigned e) const;
367  bool isInvariant(unsigned e) const;
368  Type inferType(unsigned e, Value src);
369 
370  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
371  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
372 
373  /// Merger data structures.
374  const unsigned outTensor;
375  const unsigned syntheticTensor;
376  const unsigned numTensors;
377  const unsigned numNativeLoops;
378  const unsigned numLoops;
379  bool hasSparseOut;
380  // Map that converts pair<tensor id, loop id> to the corresponding dimension
381  // level type.
382  std::vector<std::vector<DimLevelType>> dimTypes;
383  // Map that converts pair<tensor id, loop id> to the corresponding dimension.
384  std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
385  // Map that converts pair<tensor id, dim> to the corresponding loop id.
386  std::vector<std::vector<Optional<unsigned>>> dimToLoopIdx;
387  llvm::SmallVector<TensorExp> tensorExps;
388  llvm::SmallVector<LatPoint> latPoints;
390 };
391 
392 } // namespace sparse_tensor
393 } // namespace mlir
394 
395 #endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
A class to handle all iteration lattice operations.
Definition: Merger.h:149
unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, bool includeLeft, Kind ltrans, Operation *opleft, bool includeRight, Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets L0 and L1 with custom handling of the overlap,...
Definition: Merger.cpp:194
void setHasSparseOut(bool s)
Definition: Merger.h:333
Optional< unsigned > getLoopIdx(unsigned t, unsigned dim) const
Definition: Merger.h:294
unsigned addExp(Kind k, unsigned e, Value v, Operation *op=nullptr)
Definition: Merger.h:184
bool isFilterLoop(unsigned ldx) const
Definition: Merger.h:269
unsigned optimizeSet(unsigned s0)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:226
bool latGT(unsigned i, unsigned j) const
Returns true if Li > Lj.
Definition: Merger.cpp:294
unsigned buildLattices(unsigned e, unsigned i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:638
unsigned addExp(Kind k, unsigned e0, unsigned e1=-1u, Value v=Value(), Operation *op=nullptr)
Adds a tensor expression. Returns its index.
Definition: Merger.cpp:137
bool isOutTensor(unsigned b, unsigned i) const
Returns true if bit corresponds to index of output tensor.
Definition: Merger.h:259
void dumpBits(const BitVector &bits) const
Definition: Merger.cpp:611
unsigned getFilterLoopStartingIdx() const
Get the starting filter loop index.
Definition: Merger.h:256
void dumpSet(unsigned s) const
Definition: Merger.cpp:602
LatPoint & lat(unsigned l)
Definition: Merger.h:340
unsigned getNumLoops() const
Get the number of total loops (native loops + filter loops).
Definition: Merger.h:250
unsigned getNumFilterLoops() const
Get the number of filter loops.
Definition: Merger.h:254
void dumpExp(unsigned e) const
Print methods (for debugging).
Definition: Merger.cpp:508
unsigned mapSet(Kind kind, unsigned s0, Value v=Value(), Operation *op=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
Definition: Merger.cpp:215
void foreachTidDimPairInBits(const BitVector &bits, function_ref< void(unsigned b, unsigned tid, Optional< unsigned > dim, DimLevelType dlt)> cb)
Definition: Merger.h:323
DimLevelType getDimLevelType(unsigned t, unsigned i) const
Gets the dimension level type of the tth tensor on ith loop.
Definition: Merger.h:284
TensorExp & exp(unsigned e)
Convenience getters to immediately access the stored nodes.
Definition: Merger.h:339
Merger(unsigned t, unsigned l, unsigned fl)
Constructs a merger for the given number of tensors, native loops, and filter loops.
Definition: Merger.h:171
Optional< unsigned > getDimNum(unsigned b) const
Gets the dimension number of b.
Definition: Merger.h:306
bool onlyDenseDiff(unsigned i, unsigned j)
Returns true if Li and Lj only differ in dense.
Definition: Merger.cpp:307
unsigned addExp(Kind k, Value v, Operation *op=nullptr)
Definition: Merger.h:187
SmallVector< unsigned > & set(unsigned s)
Definition: Merger.h:341
void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim, DimLevelType dlt)
Sets the dimension and dimension level type of the tth tensor on ith loop.
Definition: Merger.h:312
unsigned getOutTensorID() const
Gets tensor ID for the output tensor.
Definition: Merger.h:264
DimLevelType getDimLevelType(unsigned b) const
Gets the dimension level type of b.
Definition: Merger.h:290
unsigned addSet()
Adds a new, initially empty, set. Returns its index.
Definition: Merger.cpp:151
unsigned getSynTensorID() const
Gets tensor ID for the synthetic tensor (used for all invariant tensor expressions).
Definition: Merger.h:267
unsigned getNumNativeLoops() const
Get the number of native loops.
Definition: Merger.h:252
unsigned tensor(unsigned b) const
Bit translation (get tensor ID).
Definition: Merger.h:245
void dumpLat(unsigned p) const
Definition: Merger.cpp:592
unsigned takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets L0 and L1 is conjunction of cartesian product.
Definition: Merger.cpp:167
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of loop indices (effectivel...
Definition: Merger.cpp:157
Optional< unsigned > getDimNum(unsigned t, unsigned i) const
Gets the dimension number of the tth tensor on the ith loop.
Definition: Merger.h:300
Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, Value v1)
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:1122
bool hasAnySparse(const BitVector &bits) const
Returns true if any set bit corresponds to sparse dimension level type.
Definition: Merger.cpp:397
unsigned index(unsigned b) const
Bit translation (get loop index).
Definition: Merger.h:247
Optional< unsigned > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
Definition: Merger.cpp:834
unsigned addLat(unsigned t, unsigned i, unsigned e)
Adds an iteration lattice point. Returns its index.
Definition: Merger.cpp:144
bool isSingleCondition(unsigned t, unsigned e) const
Returns true if given tensor iterates only in the given tensor expression.
Definition: Merger.cpp:313
BitVector simplifyCond(unsigned s0, unsigned p0)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
Definition: Merger.cpp:255
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
Definition: Merger.cpp:175
Kind
Tensor expression kind.
Definition: Merger.h:25
constexpr bool isValidDLT(DimLevelType dlt)
Check that the DimLevelType contains a valid (possibly undefined) value.
Definition: Enums.h:161
DimLevelType
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:147
Include the generated interface declarations.
Children subexpressions of tensor operations.
Definition: Merger.h:90
LatPoint(unsigned n, unsigned e, unsigned b)
Definition: Merger.cpp:126
BitVector bits
Conjunction of tensor loop indices as bitvector.
Definition: Merger.h:134
BitVector simple
Simplified conjunction of tensor loop indices as bitvector.
Definition: Merger.h:139
unsigned exp
Index of the tensor expression.
Definition: Merger.h:142
Tensor expression. Represents a MLIR expression in tensor index notation.
Definition: Merger.h:96
TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation)
Definition: Merger.cpp:25
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition: Merger.h:116
Children children
Tensor operations hold the indices of their children.
Definition: Merger.h:110
Kind kind
Tensor expression kind.
Definition: Merger.h:100
unsigned index
Indices hold the index number.
Definition: Merger.h:107
unsigned tensor
Expressions representing tensors simply have a tensor number.
Definition: Merger.h:104
Operation * op
Code blocks used by semirings.
Definition: Merger.h:122
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.