MLIR  19.0.0git
CodegenEnv.h
Go to the documentation of this file.
1 //===- CodegenEnv.h - Code generation environment class ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines the code generation environment class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
15 
16 #include "CodegenUtils.h"
17 #include "LoopEmitter.h"
18 
23 #include <optional>
24 
25 namespace mlir {
26 namespace sparse_tensor {
27 
28 /// The code generation environment class aggregates a number of data
29 /// structures that are needed during the code generation phase of
30 /// sparsification. This environment simplifies passing around such
31 /// data during sparsification (rather than passing around all the
32 /// individual compoments where needed). Furthermore, it provides
33 /// convience methods that keep implementation details transparent
34 /// to sparsification while asserting on internal consistency.
35 class CodegenEnv {
36 public:
37  /// Constructs a code generation environment which can be
38  /// passed around during sparsification for bookkeeping
39  /// together with some consistency asserts.
40  CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
41  unsigned numTensors, unsigned numLoops, unsigned maxRank);
42 
43  //
44  // General methods.
45  //
46 
48  ExprId getExprId() const { return tensorExp; }
49 
50  linalg::GenericOp op() const { return linalgOp; }
51  const SparsificationOptions &options() const { return sparseOptions; }
52  Merger &merger() { return latticeMerger; }
53  LoopEmitter &emitter() { return loopEmitter; }
54 
55  void startEmit(SparseEmitStrategy emitStrategy);
56 
57  /// Generates loop boundary statements (entering/exiting loops). The function
58  /// passes and updates the passed-in parameters.
59  std::optional<Operation *>
61  std::optional<Operation *>(MutableArrayRef<Value> parameters)>
62  callback);
63 
64  //
65  // Merger delegates.
66  //
67 
68  constexpr TensorId makeTensorId(unsigned t) const {
69  return latticeMerger.makeTensorId(t);
70  }
71  constexpr LoopId makeLoopId(unsigned i) const {
72  return latticeMerger.makeLoopId(i);
73  }
74  constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
75  return latticeMerger.makeTensorLoopId(t, i);
76  }
77  const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
78  const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
79  ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
80  LevelType lt(TensorId t, LoopId i) const {
81  return latticeMerger.getLvlType(t, i);
82  }
83  LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
84 
85  unsigned getLoopNum() const { return latticeMerger.getNumLoops(); }
86 
87  //
88  // LoopEmitter delegates.
89  //
90 
92  // Make sure LoopEmitter, GenericOp, and Merger agree on the number of
93  // tensors.
94  assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
95  loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
96  loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
97  loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
98  return loopEmitter.makeTensorLevel(t, l);
99  }
100  TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
101  return makeTensorLevel(tlPair.first, tlPair.second);
102  }
103  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
104  return loopEmitter.unpackTensorLevel(tl);
105  }
106  template <class ContainerTy>
107  auto unpackTensorLevelRange(ContainerTy &&c) const {
108  return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
109  }
110 
111  unsigned getCurrentDepth() const { return loopEmitter.getCurrentDepth(); }
112 
113  //
114  // Code generation environment verify functions.
115  //
116 
117  /// Whether the tensor expression is admissible for codegen.
118  /// It also sets the sparseOut if the output tensor is sparse.
120 
121  /// Returns the induction-variable for the given loop.
122  Value getLoopVar(LoopId i) const;
123 
124  //
125  // Sparse tensor output and expansion methods.
126  //
127 
128  bool hasSparseOutput() const { return sparseOut != nullptr; }
129  bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }
130 
131  Value getInsertionChain() const { return insChain; }
132  void updateInsertionChain(Value chain);
133 
134  bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const;
135  void startExpand(Value values, Value filled, Value added, Value count);
136  bool isExpand() const { return expValues != nullptr; }
137  void updateExpandCount(Value count);
138  Value getExpandValues() const { return expValues; }
139  Value getExpandFilled() const { return expFilled; }
140  Value getExpandAdded() const { return expAdded; }
141  Value getExpandCount() const { return expCount; }
142  void endExpand();
143 
144  //
145  // Reduction methods.
146  //
147 
148  void startReduc(ExprId exp, Value val);
149  bool isReduc() const { return redExp != detail::kInvalidId; }
150  void updateReduc(Value val);
151  Value getReduc() const { return redVal; }
152  Value endReduc();
153 
154  void startValidLexInsert(Value val);
155  bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
156  void updateValidLexInsert(Value val);
157  Value getValidLexInsert() const { return redValidLexInsert; }
158  void endValidLexInsert();
159 
161  bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
162  Value getCustomRedId() const;
163  void endCustomReduc();
164 
165 private:
166  // Linalg operation.
167  linalg::GenericOp linalgOp;
168 
169  // Sparsification options.
170  SparsificationOptions sparseOptions;
171 
172  // Merger helper class.
173  Merger latticeMerger;
174 
175  // Loop emitter helper class.
176  LoopEmitter loopEmitter;
177 
178  // Sparse tensor as output. Implemented either through direct injective
179  // insertion in lexicographic index order or through access pattern
180  // expansion in the innermost loop nest (`expValues` through `expCount`).
181  OpOperand *sparseOut;
182  // The count of outer non-filter loops, as defined by `isAdmissibleTopoOrder`.
183  LoopId outerParNest;
184  Value insChain;
185  Value expValues;
186  Value expFilled;
187  Value expAdded;
188  Value expCount;
189 
190  // Bookkeeping for reductions (up-to-date value of the reduction, and indices
191  // into the merger's expression tree. When the indices of a tensor reduction
192  // expression are exhausted, all inner loops can use a scalarized reduction.
193  Value redVal;
194  ExprId redExp;
195  ExprId redCustom;
196 
197  // Bookkeeping for lex insertion during reductions. Holds the runtime boolean
198  // value of whether any reduction occurred. This is only set during a
199  // reduction and cleared once the reduction is finished.
200  Value redValidLexInsert;
201 
202  // The root tensor expression of the kernel.
203  ExprId tensorExp;
204 };
205 
206 } // namespace sparse_tensor
207 } // namespace mlir
208 
209 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
This class represents an operand of an operation.
Definition: Value.h:267
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
The code generation environment class aggregates a number of data structures that are needed during t...
Definition: CodegenEnv.h:35
void startReduc(ExprId exp, Value val)
Definition: CodegenEnv.cpp:228
void updateValidLexInsert(Value val)
Definition: CodegenEnv.cpp:256
const SparsificationOptions & options() const
Definition: CodegenEnv.h:51
std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)
Generates loop boundary statements (entering/exiting loops).
Definition: CodegenEnv.cpp:103
ArrayRef< LatPointId > set(LatSetId s) const
Definition: CodegenEnv.h:79
bool isAdmissibleTensorExp(ExprId e)
Whether the tensor expression is admissible for codegen.
Definition: CodegenEnv.cpp:136
bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const
Definition: CodegenEnv.cpp:200
CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, unsigned numTensors, unsigned numLoops, unsigned maxRank)
Constructs a code generation environment which can be passed around during sparsification for bookkee...
Definition: CodegenEnv.cpp:44
unsigned getCurrentDepth() const
Definition: CodegenEnv.h:111
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tl) const
Definition: CodegenEnv.h:103
TensorLevel makeTensorLevel(TensorId t, Level l) const
Definition: CodegenEnv.h:91
const LatPoint & lat(LatPointId l) const
Definition: CodegenEnv.h:78
constexpr TensorId makeTensorId(unsigned t) const
Definition: CodegenEnv.h:68
LevelType lt(TensorLoopId b) const
Definition: CodegenEnv.h:83
void startExpand(Value values, Value filled, Value added, Value count)
Definition: CodegenEnv.cpp:205
unsigned getLoopNum() const
Definition: CodegenEnv.h:85
void updateInsertionChain(Value chain)
Definition: CodegenEnv.cpp:195
void startCustomReduc(ExprId exp)
Definition: CodegenEnv.cpp:266
TensorLevel makeTensorLevel(std::pair< TensorId, Level > tlPair) const
Definition: CodegenEnv.h:100
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Definition: CodegenEnv.h:74
linalg::GenericOp op() const
Definition: CodegenEnv.h:50
Value getLoopVar(LoopId i) const
Returns the induction-variable for the given loop.
Definition: CodegenEnv.cpp:187
void startEmit(SparseEmitStrategy emitStrategy)
Definition: CodegenEnv.cpp:62
auto unpackTensorLevelRange(ContainerTy &&c) const
Definition: CodegenEnv.h:107
const TensorExp & exp(ExprId e) const
Definition: CodegenEnv.h:77
void updateExpandCount(Value count)
Definition: CodegenEnv.cpp:214
bool isSparseOutput(OpOperand *o) const
Definition: CodegenEnv.h:129
void startValidLexInsert(Value val)
Definition: CodegenEnv.cpp:251
constexpr LoopId makeLoopId(unsigned i) const
Definition: CodegenEnv.h:71
LevelType lt(TensorId t, LoopId i) const
Definition: CodegenEnv.h:80
TensorId getOutTensorId() const
Gets the TensorId for output tensor.
Definition: LoopEmitter.h:193
TensorLevel makeTensorLevel(TensorId t, Level l) const
Compresses a TensorId and Level into a TensorLevel.
Definition: LoopEmitter.h:199
unsigned getNumManifestTensors() const
Gets the total number of manifest tensors (excluding the synthetic tensor).
Definition: LoopEmitter.h:181
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tidLvl) const
De-compresses a TensorLevel back to a pair of TensorId and Level.
Definition: LoopEmitter.h:204
auto unpackTensorLevelRange(ContainerTy &&c) const
Converts a range of TensorLevel to a range of std::pair<TensorId, Level>
Definition: LoopEmitter.h:211
unsigned getNumTensors() const
Gets the total number of tensors that loopEmitter is operating on.
Definition: LoopEmitter.h:184
LoopId getCurrentDepth() const
Gets the current depth of the loop-stack.
Definition: LoopEmitter.h:168
TensorId getSynTensorId() const
Gets the TensorId for synthetic tensor.
Definition: LoopEmitter.h:190
A class to handle all iteration lattice operations.
Definition: Merger.h:224
constexpr unsigned getNumLoops() const
Gets the total number of loops (native loops + filter loops).
Definition: Merger.h:354
constexpr unsigned getNumTensors() const
Gets the total number of tensors (including the output-tensor and synthetic-tensor).
Definition: Merger.h:351
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition: Merger.h:248
ArrayRef< LatPointId > set(LatSetId s) const
Definition: Merger.h:548
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
Definition: Merger.h:366
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
Definition: Merger.h:540
const LatPoint & lat(LatPointId p) const
Definition: Merger.h:544
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
Definition: Merger.h:362
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
Definition: Merger.h:242
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Definition: Merger.h:398
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
Definition: Merger.h:254
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition: Merger.h:30
unsigned LatSetId
LatSet identifiers.
Definition: Merger.h:57
unsigned TensorLevel
Definition: LoopEmitter.h:26
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition: Merger.h:44
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
unsigned LoopId
Loop identifiers.
Definition: Merger.h:38
unsigned ExprId
TensorExp identifiers.
Definition: Merger.h:48
unsigned LatPointId
LatPoint identifiers.
Definition: Merger.h:52
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition: Merger.h:35
Include the generated interface declarations.
SparseEmitStrategy
Defines a scope for reinterpret map pass.
Definition: Passes.h:51
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for the Sparsification pass.
Definition: Passes.h:91
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
Tensor expression. Represents an MLIR expression in tensor index notation.
Definition: Merger.h:67