Sparsification.cpp
1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
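// For example (illustrative): a kernel such as x(i) = SUM_j A(i,j) * b(j),
// written as a linalg.generic op whose matrix operand carries a sparse_tensor
// encoding (e.g. CSR), is rewritten here into explicit loops that only visit
// the stored entries of the sparse operand.
//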
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenEnv.h"
14 #include "Utils/CodegenUtils.h"
15 #include "Utils/LoopEmitter.h"
16 
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TensorEncoding.h"
36 #include "llvm/ADT/SmallBitVector.h"
37 
38 #include <optional>
39 
40 using namespace mlir;
41 using namespace mlir::sparse_tensor;
42 
43 //===----------------------------------------------------------------------===//
44 // Sparsifier analysis methods.
45 //===----------------------------------------------------------------------===//
46 
47 /// Returns true iff the affine expression is invariant. Sets the
48 /// parameter `isCurrentLoop` when the expression just becomes invariant.
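/// For example, with curr == 3 (loops d0..d2 already generated), d0 + d1 is
/// already invariant, d0 + d2 just becomes invariant (and sets
/// `isCurrentLoop`, since d2 is loop curr - 1), and d3 is not invariant yet.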
49 static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
50  switch (a.getKind()) {
51  case AffineExprKind::DimId: {
52  const LoopId i = cast<AffineDimExpr>(a).getPosition();
53  if (i + 1 == curr) {
54  isCurrentLoop = true;
55  return true; // becomes invariant at current loop
56  }
57  return i < curr; // invariant when already generated
58  }
 59  case AffineExprKind::Add:
 60  case AffineExprKind::Mul: {
61  auto binOp = cast<AffineBinaryOpExpr>(a);
62  return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
63  isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
64  }
65  default: {
66  assert(isa<AffineConstantExpr>(a));
67  return true;
68  }
69  }
70 }
71 
72 /// Helper method to inspect affine expressions. Rejects cases where the
73 /// same index is used more than once. Also rejects compound affine
74 /// expressions in sparse dimensions.
75 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
76  LevelType lt, bool setLvlFormat = true) {
77  switch (a.getKind()) {
78  case AffineExprKind::DimId: {
79  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
80  if (!isUndefLT(merger.getLvlType(tid, idx)))
81  return false; // used more than once
82  if (setLvlFormat)
83  merger.setLevelAndType(tid, idx, lvl, lt);
84  return true;
85  }
 86  case AffineExprKind::Add:
 87  case AffineExprKind::Mul:
 88  case AffineExprKind::Constant: {
 89  assert(lt.hasDenseSemantic());
90  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
 91  // We do not set the dim level format for an affine expression like d0 + d1
 92  // on either loop index d0 or d1. We continue the recursion merely to
 93  // check whether the current affine expression is admissible or not.
94  return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
95  findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
96  }
 97  // Falls through when it is a constant affine expression.
98  return true;
99  }
100  default:
101  return false;
102  }
103 }
104 
105 /// Helper method to inspect affine expressions for index variable reduction
106 /// based codegen. It finds the dependent index set for all tensor levels in the
107 /// current expression we are generating.
108 ///
 109 /// For example, when handling A[i+j][j+k], we build the two-way mapping in
 110 /// the merger between (tensor, level) pairs and their dependent index variable
 111 /// sets: A_0 <=> [i, j] and A_1 <=> [j, k].
112 ///
 113 /// It rejects cases (returns false) when:
 114 /// 1. the same index is used more than once, e.g., A[i+j][i];
 115 /// 2. multiplication is used in the non-trivial index expression;
 116 /// 3. a constant operand is used in the non-trivial index expression.
117 ///
118 /// TODO: constant should be easy to handle.
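/// For example (illustrative), for A[2 * d0 + d1] the level of A becomes
/// dependent on loops {d0, d1}, recording coefficient 2 for d0 and 1 for d1.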
119 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
120  AffineExpr a, LevelType lt, bool isSubExp = false,
121  int64_t coefficient = 1) {
122  switch (a.getKind()) {
123  case AffineExprKind::DimId: {
124  // Only allow positive coefficients on AffineDimExpr.
125  if (coefficient <= 0)
126  return false;
127 
128  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
129  if (!isUndefLT(merger.getLvlType(tensor, idx)))
130  return false; // used more than once, e.g., A[i][i]
131 
 132  // TODO: Generalize the following two cases. A[i] (with trivial index
 133  // expression) can be treated as a special affine index expression. We do
 134  // not necessarily need to differentiate them.
135  if (!isSubExp) {
136  assert(coefficient == 1);
137  merger.setLevelAndType(tensor, idx, lvl, lt);
138  }
139 
140  if (isSubExp) {
 141  // The current loop appears in more than one affine expression on the
 142  // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
 143  // is used twice.
144  if (merger.hasDependentLvl(idx, tensor)) {
 145  // TODO: This can be supported by coiterating slices if the loop index
 146  // appears in affine indices of different tensors, or by taking slices
 147  // along multiple dimensions when it appears on the same tensor.
 148  // E.g.,
 149  // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
 150  // d0_1 = getNextSliceOffset t0 along lvl0
 151  // d0_2 = getNextSliceOffset t1 along lvl0
 152  // if d0_1 == d0_2 then d0 = d0_1 = d0_2
153  // else increase min(d0_1, d0_2).
154  return false;
155  }
156  merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
157  }
158  return true;
159  }
 160  case AffineExprKind::Constant:
 161  case AffineExprKind::Mul: {
 162  // TODO: Support index expressions like `2 * d0`; we currently only support
 163  // more complicated cases like `2 * d0 + d1`.
164  if (!isSubExp)
165  return false;
166 
167  // TODO: Support Constant AffineExp for slice-based codegen
168  if (isa<AffineConstantExpr>(a))
169  llvm_unreachable("Not yet implemented");
170 
171  auto binOp = cast<AffineBinaryOpExpr>(a);
172  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
173  if (isa<AffineConstantExpr>(rhs))
174  std::swap(lhs, rhs);
 175  // Must be in the form `constant * d`.
176  assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
177  int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
178  return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
179  }
180  case AffineExprKind::Add: {
181  auto binOp = cast<AffineBinaryOpExpr>(a);
182  return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
183  findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
184  }
185  default:
186  return false;
187  }
188 }
189 
190 /// Gets the total number of compound affine expressions in the
191 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
192 ///
193 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
194 ///
195 /// Returns 1 (because the first level is compressed and its corresponding
196 /// indexing-expression is `d0 + d1`)
 197 static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
 198  Value tensor) {
199  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
200  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
201  // However, we don't need to handle `StorageSpecifierType`, so we
202  // can use `SparseTensorType` once we guard against non-tensors.
203  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
204  if (!rtp)
205  return 0;
206  const SparseTensorType stt(rtp);
207 
208  const Level lvlRank = stt.getLvlRank();
209  const auto exprs = map.getResults();
210  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
211  "AffineMap does not have dimension-rank many results");
212  unsigned num = 0;
213  for (Level l = 0; l < lvlRank; l++) {
214  if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
215  num++;
216  }
217  return num;
218 }
219 
220 /// Gets the total number of sparse levels with compound affine
221 /// expressions, summed over all operands of the `GenericOp`.
222 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
223  unsigned num = 0;
224  for (OpOperand &t : op->getOpOperands())
225  num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
226  t.get());
227  return num;
228 }
229 
230 // Returns true iff output has nontrivial affine indices.
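// E.g., an output indexed as X[i+j] where that level of X is compressed.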
231 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
232  OpOperand *out = op.getDpsInitOperand(0);
233  if (getSparseTensorType(out->get()).isAllDense())
234  return false;
235  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
236  out->get());
237 }
238 
239 /// Helper method to inspect sparse encodings in the tensor types.
240 /// Fills the per-dimension sparsity information for all tensors.
241 /// Returns true if the sparse annotations and affine subscript
242 /// expressions of all tensors are admissible. Returns false if
243 /// no annotations are found or inadmissible constructs occur.
 244 /// We currently support two different ways to handle non-trivial index
 245 /// expressions on sparse tensors, and they accept different affine expressions.
 246 /// When using the dependent index reduction-based approach, it currently only
 247 /// supports affine addition index expressions.
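/// For example (illustrative), an operand of type tensor<?x?xf64, #CSR>,
/// where #CSR is a sparse_tensor.encoding attribute, counts as an annotated
/// tensor here.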
248 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
249  bool annotated = false;
250  for (OpOperand &t : env.op()->getOpOperands()) {
251  const TensorId tid = env.makeTensorId(t.getOperandNumber());
252  const auto map = env.op().getMatchingIndexingMap(&t);
253  const auto enc = getSparseTensorEncoding(t.get().getType());
254  if (enc)
255  annotated = true;
256  const Level lvlRank = map.getNumResults();
257  assert(!enc || lvlRank == enc.getLvlRank());
258  assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
259  // We only need to do index reduction if there is at least one
260  // non-trivial index expression on sparse levels. If all non-trivial
 261  // index expressions are on dense levels, we can efficiently rely on
 262  // random access to locate the element.
263  bool needIdxReduc =
264  enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
 265  // If the current tensor being inspected requires an affine index, it
 266  // needs to be sliced.
267  for (Level l = 0; l < lvlRank; l++) {
268  const AffineExpr a = map.getResult(l);
269  const LevelType lt = enc.getLvlType(l);
270  if (idxReducBased && needIdxReduc) {
271  if (!findDepIdxSet(env.merger(), tid, l, a, lt))
272  return false; // inadmissible affine expression
273  } else {
274  if (!findAffine(env.merger(), tid, l, a, lt))
275  return false; // inadmissible affine expression
276  }
277  }
278  }
279  return annotated;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Sparsifier synthesis methods (statements and expressions).
284 //===----------------------------------------------------------------------===//
285 
286 /// Local bufferization of all dense and sparse data structures.
287 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
288  linalg::GenericOp op = env.op();
289  Location loc = op.getLoc();
290  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
291 
292  SmallVector<Range, 4> loopRange =
293  llvm::cast<linalg::LinalgOp>(op.getOperation())
294  .createLoopRanges(builder, loc);
295 
 296  env.emitter().initializeLoopEmit(
 297  builder, loc,
298  /// Generates buffer for the output tensor.
299  /// Note that all sparse kernels assume that when all elements are written
300  /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
 301  /// to all zeroes and only nonzero values are computed and written out.
 302  /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
303  /// for the updates and no assumption on the original contents of the
304  /// output buffer is necessary.
305  [&op](OpBuilder &builder, Location loc, Value memref,
306  Value tensor) -> Value {
307  // Must not be a sparse tensor.
308  assert(!getSparseTensorEncoding(tensor.getType()));
309  // Two output tensor references should point to the same object.
310  OpOperand *lhs = op.getDpsInitOperand(0);
311  assert(lhs->get() == tensor);
312  // An output tensor can simply materialize from the buffer of the tensor
313  // that appears in the outs() clause. For updates, this has the
 314  // advantage that only the nonzero values are involved in the
315  // computation, keeping the operation O(nnz). In all other cases, we are
316  // forced to zero out the buffer to enforce the assumption above, which
317  // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
318  // O(nnz) for matrices).
319  // TODO: use better analysis to avoid zeroing out the buffer?
320  bool isInit = op.isInitTensor(lhs);
321  Value init = memref;
322  if (!isInit) {
323  Value zero = constantZero(builder, loc,
324  getElementTypeOrSelf(tensor.getType()));
325  builder.create<linalg::FillOp>(loc, ValueRange{zero},
326  ValueRange{init});
327  }
328  return init;
329  },
330  [&loopRange](OpBuilder &b, Location loc, Level l) {
331  assert(l < loopRange.size());
332  return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
333  });
334 }
335 
336 /// Generates index for load/store on sparse tensor.
337 static Value genIndex(CodegenEnv &env, OpOperand *t) {
338  const auto map = env.op().getMatchingIndexingMap(t);
339  const auto stt = getSparseTensorType(t->get());
340  const Level lvlRank = stt.getLvlRank();
341  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
342  const AffineExpr a = map.getResult(lvlRank - 1);
343  assert(a.getKind() == AffineExprKind::DimId);
344  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
345  return env.getLoopVar(idx);
346 }
347 
348 /// Generates subscript for load/store on a dense or sparse tensor.
349 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
350  SmallVectorImpl<Value> &args) {
351  const Location loc = env.op().getLoc();
352  const TensorId tid = env.makeTensorId(t->getOperandNumber());
353  const auto map = env.op().getMatchingIndexingMap(t);
354  const auto stt = getSparseTensorType(t->get());
355  if (stt.hasEncoding()) {
356  // For sparse tensors we only push the last-level's position onto `args`.
357  const auto pos = env.emitter().getValPosits(tid);
358  assert(!pos.empty());
359  args.append(pos);
360  } else {
361  // For dense tensors we push all level's coordinates onto `args`.
362  const Level lvlRank = stt.getLvlRank();
363  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
364  for (Level l = 0; l < lvlRank; l++) {
365  const auto lvlExpr = map.getResult(l);
366  const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
367  args.push_back(lvlCrd);
368  }
369  }
370  return env.emitter().getValBuffer()[tid];
371 }
372 
373 /// Generates insertion code to implement dynamic tensor load.
 374 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
 375  OpOperand *t) {
376  linalg::GenericOp op = env.op();
377  Location loc = op.getLoc();
378  // Direct lexicographic coordinate order, tensor loads as zero.
379  if (!env.isExpand()) {
380  Type tp = getElementTypeOrSelf(t->get().getType());
381  return constantZero(builder, loc, tp);
382  }
383  // Load from expanded access pattern.
384  Value index = genIndex(env, t);
385  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
386 }
387 
388 /// Generates insertion code to implement dynamic tensor load for reduction.
 389 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
 390  OpOperand *t) {
391  linalg::GenericOp op = env.op();
392  Location loc = op.getLoc();
393  Value identity = env.getCustomRedId();
394  // Direct lexicographic coordinate order, tensor loads as identity.
395  if (!env.isExpand())
396  return identity;
397  // Load from expanded access pattern if filled, identity otherwise.
398  Value values = env.getExpandValues();
399  Value filled = env.getExpandFilled();
400  Value index = genIndex(env, t);
401  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
402  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
403  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
404 }
405 
406 /// Generates insertion code to implement dynamic tensor store.
407 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
408  Value rhs) {
409  linalg::GenericOp op = env.op();
410  Location loc = op.getLoc();
411  // Direct insertion in lexicographic coordinate order.
412  if (!env.isExpand()) {
413  const LoopId numLoops = op.getRank(t);
 414  // Retrieves the first `numLoops` induction variables.
415  SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
416  env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
417  Value chain = env.getInsertionChain();
418  if (env.isValidLexInsert()) {
419  // Generates runtime check for a valid lex during reduction,
420  // to avoid inserting the identity value for empty reductions.
421  // if (validLexInsert) then
422  // insert(rhs) into chain
423  // return updated chain
424  // else
425  // return unmodified chain
426  scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
427  loc, chain.getType(), env.getValidLexInsert(),
428  /*else=*/true);
429  // True branch.
430  builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
431  Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
432  builder.create<scf::YieldOp>(loc, res);
433  // False branch.
434  builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
435  builder.create<scf::YieldOp>(loc, chain);
436  // Value assignment.
437  builder.setInsertionPointAfter(ifValidLexInsert);
438  env.updateInsertionChain(ifValidLexInsert.getResult(0));
439  } else {
440  // Generates regular insertion chain.
 441  env.updateInsertionChain(
 442  builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
443  }
444  return;
445  }
446  // Generates insertion code along expanded access pattern.
447  // if (!expFilled[i]) then
448  // expFilled[i] = true
449  // expAdded[inserts++] = i
450  // endif
451  // values[i] = rhs
452  Value values = env.getExpandValues();
453  Value filled = env.getExpandFilled();
454  Value added = env.getExpandAdded();
455  Value count = env.getExpandCount();
456  Value index = genIndex(env, t);
457  Value fval = constantI1(builder, loc, false);
458  Value tval = constantI1(builder, loc, true);
459  // If statement.
460  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
461  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
462  isFilled, fval);
463  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
464  /*else=*/true);
465  // True branch.
466  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
467  builder.create<memref::StoreOp>(loc, tval, filled, index);
468  builder.create<memref::StoreOp>(loc, index, added, count);
469  Value one = constantIndex(builder, loc, 1);
470  Value add = builder.create<arith::AddIOp>(loc, count, one);
471  builder.create<scf::YieldOp>(loc, add);
472  // False branch.
473  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
474  builder.create<scf::YieldOp>(loc, count);
475  builder.setInsertionPointAfter(ifOp);
476  // Value assignment.
477  env.updateExpandCount(ifOp.getResult(0));
478  builder.create<memref::StoreOp>(loc, rhs, values, index);
479 }
480 
481 /// Generates a load on a dense or sparse tensor.
482 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
483  // Test if the load was hoisted to a higher loop nest.
484  Value val = env.exp(exp).val;
485  if (val)
486  return val;
487  // Load during insertion.
488  linalg::GenericOp op = env.op();
489  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
490  if (env.isSparseOutput(t)) {
491  if (env.isCustomReduc())
492  return genInsertionLoadReduce(env, builder, t);
493  return genInsertionLoad(env, builder, t);
494  }
495  // Actual load.
496  SmallVector<Value> args;
497  Value ptr = genSubscript(env, builder, t, args);
498  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
499 }
500 
501 /// Generates a store on a dense or sparse tensor.
502 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
503  Value rhs) {
 504  // Only unary and binary ops are allowed to return an uninitialized rhs
 505  // to indicate a missing output, or otherwise a custom reduction that
 506  // received no value to accumulate.
507  if (!rhs) {
508  assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
509  env.exp(exp).kind == TensorExp::Kind::kBinary ||
510  env.exp(exp).kind == TensorExp::Kind::kReduce);
511  return;
512  }
513  // Test if this is a scalarized reduction.
514  if (env.isReduc()) {
515  env.updateReduc(rhs);
516  return;
517  }
518  // Regular store.
519  linalg::GenericOp op = env.op();
520  Location loc = op.getLoc();
521  OpOperand *t = op.getDpsInitOperand(0);
522  if (!env.isSparseOutput(t)) {
523  SmallVector<Value> args;
524  Value ptr = genSubscript(env, builder, t, args);
525  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
526  return;
527  }
528  // Store during sparse insertion.
529  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
530  genInsertionStore(env, builder, t, rhs);
531  return;
532  }
533  // Select operation insertion.
534  Value chain = env.getInsertionChain();
535  scf::IfOp ifOp =
536  builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
537  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
538  // Existing value was preserved to be used here.
539  assert(env.exp(exp).val);
540  Value v0 = env.exp(exp).val;
541  genInsertionStore(env, builder, t, v0);
542  env.merger().clearExprValue(exp);
543  // Yield modified insertion chain along true branch.
544  Value mchain = env.getInsertionChain();
545  builder.create<scf::YieldOp>(op.getLoc(), mchain);
546  // Yield original insertion chain along false branch.
547  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
548  builder.create<scf::YieldOp>(loc, chain);
549  // Done with if statement.
550  env.updateInsertionChain(ifOp->getResult(0));
551  builder.setInsertionPointAfter(ifOp);
552 }
553 
554 /// Generates an invariant value.
555 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
556  return env.exp(exp).val;
557 }
558 
559 /// Semi-ring branches are simply inlined by the sparsifier. Prior
560 /// analysis has verified that all computations are "local" to the inlined
561 /// branch or otherwise invariantly defined outside the loop nest, with the
562 /// exception of index computations, which need to be relinked to actual
563 /// inlined cloned code.
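/// For example, a linalg.index op cloned into the branch is replaced by the
/// corresponding loop variable, and a block argument of the original
/// linalg.generic op is turned into a load from its (dense) buffer.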
564 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
565  Value e) {
566  if (auto arg = dyn_cast<BlockArgument>(e)) {
567  // Direct arguments of the original linalg op must be converted
568  // into dense tensor loads. Note that we should not encounter
569  // anything else. This needs to be verified by semi-ring ops.
570  linalg::GenericOp op = env.op();
571  if (arg.getOwner()->getParentOp() == op) {
572  const TensorId tid = env.makeTensorId(arg.getArgNumber());
573  OpOperand *t = &op->getOpOperand(tid);
574  assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
575  SmallVector<Value> args;
576  Value ptr = genSubscript(env, rewriter, t, args);
577  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
578  }
579  } else if (Operation *def = e.getDefiningOp()) {
580  // Handle index computation.
581  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
582  return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
583  // When still defined in new body, recurse into operands.
584  if (def->getBlock() == block) {
585  rewriter.setInsertionPoint(def);
586  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
587  rewriter.modifyOpInPlace(def, [&]() {
588  def->setOperand(
589  i, relinkBranch(env, rewriter, block, def->getOperand(i)));
590  });
591  }
592  }
593  }
594  return e;
595 }
596 
597 /// Recursively generates tensor expression.
598 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
 599  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
 600  return Value();
601 
602  linalg::GenericOp op = env.op();
603  Location loc = op.getLoc();
604  const TensorExp &exp = env.exp(e);
605  const auto kind = exp.kind;
606  if (kind == TensorExp::Kind::kTensor)
607  return genTensorLoad(env, rewriter, e);
608  if (kind == TensorExp::Kind::kInvariant)
609  return genInvariantValue(env, e);
610  if (kind == TensorExp::Kind::kLoopVar)
611  return env.getLoopVar(exp.loop);
612 
613  if (kind == TensorExp::Kind::kReduce)
614  env.startCustomReduc(e); // enter custom
615 
616  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
617  // based on the type of the other operand.
618  Value v0, v1;
621  v1 = genExp(env, rewriter, exp.children.e1);
622  v0 = constantZero(rewriter, loc, v1.getType());
625  v0 = genExp(env, rewriter, exp.children.e0);
626  v1 = constantZero(rewriter, loc, v0.getType());
627  } else {
628  v0 = genExp(env, rewriter, exp.children.e0);
629  v1 = genExp(env, rewriter, exp.children.e1);
630  }
631 
632  Value ee;
633  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
634  // custom reduce did not receive a value
635  } else {
636  ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
637  if (ee &&
640  kind == TensorExp::Kind::kReduce ||
641  kind == TensorExp::Kind::kSelect)) {
642  OpBuilder::InsertionGuard guard(rewriter);
643  ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
644  }
645  }
646 
647  if (kind == TensorExp::Kind::kReduce)
648  env.endCustomReduc(); // exit custom
649 
650  if (kind == TensorExp::Kind::kSelect)
651  env.merger().setExprValue(e, v0); // Preserve value for later use.
652 
653  return ee;
654 }
655 
656 /// Hoists loop invariant tensor loads for which indices have been exhausted.
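/// For example (illustrative), in x(i) += A(i,j) * b(i), the load of b(i) is
/// emitted once at the start of the loop sequence for j and reused inside the
/// j-loop.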
657 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
658  LoopId curr, bool isStart) {
 659  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
 660  return;
661  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
662  // Inspect tensor indices.
663  linalg::GenericOp op = env.op();
664  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
665  const auto map = op.getMatchingIndexingMap(&t);
666  const auto stt = getSparseTensorType(t.get());
667  const Level lvlRank = stt.getLvlRank();
668  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
669  bool isCurrentLoop = curr == 0; // for scalar tensors
670  for (Level l = 0; l < lvlRank; l++) {
671  const AffineExpr a = map.getResult(l);
672  if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
673  return; // still in play
674  }
675  // All exhausted at current level.
676  if (!isCurrentLoop)
677  return;
678  // Generate code for a scalarized reduction or invariant. Note that
679  // because custom reduction lhs may occur several times in the IR,
680  // we have a built-in safety for only initializing and wrapping-up
681  // the scalarized reduction once.
682  OpOperand *lhs = op.getDpsInitOperand(0);
683  if (lhs == &t) {
684  // Start or end a scalarized reduction.
685  if (isStart) {
686  if (env.isCustomReduc()) {
687  if (!env.isReduc())
688  env.startReduc(exp, env.getCustomRedId());
689  } else {
690  env.startReduc(exp, genTensorLoad(env, builder, exp));
691  }
692  if (env.hasSparseOutput())
 693  env.startValidLexInsert(
 694  constantI1(builder, env.op().getLoc(), false));
695  } else {
696  if (!env.isCustomReduc() || env.isReduc())
697  genTensorStore(env, builder, exp, env.endReduc());
698  if (env.hasSparseOutput())
699  env.endValidLexInsert();
700  }
701  } else {
702  // Start or end loop invariant hoisting of a tensor load.
703  if (isStart) {
704  env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
705  } else {
706  env.merger().clearExprValue(exp);
707  }
708  }
709  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
710  env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
711  env.exp(exp).kind != TensorExp::Kind::kSynZero) {
712  // Traverse into the binary operations. Note that we only hoist
713  // tensor loads, since subsequent MLIR/LLVM passes know how to
714  // deal with all other kinds of derived loop invariants.
715  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
716  env.startCustomReduc(exp); // enter custom
717  const ExprId e0 = env.exp(exp).children.e0;
718  const ExprId e1 = env.exp(exp).children.e1;
719  genInvariants(env, builder, e0, curr, isStart);
720  genInvariants(env, builder, e1, curr, isStart);
721  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
722  env.endCustomReduc(); // exit custom
723  }
724 }
725 
726 /// Generates an expanded access pattern in innermost dimension.
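/// The expansion uses the sparse_tensor expand/compress idiom: ExpandOp
/// yields scratch buffers (values, filled, added) plus a count for one
/// innermost row, and CompressOp writes the collected entries back into the
/// insertion chain.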
727 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
728  bool isStart) {
729  linalg::GenericOp op = env.op();
730  OpOperand *lhs = op.getDpsInitOperand(0);
731  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
732  return; // not needed at current level
733  assert(!env.isReduc());
734  // Generate start or end of an expanded access pattern. Note that because
735  // an expansion does not rely on the ongoing contents of the sparse storage
736  // scheme, we can use the original tensor as incoming SSA value (which
737  // simplifies codegen a bit). If expansion on the actual contents is ever
738  // needed, we will need to use the SSA value in the insertion chain instead.
739  Value tensor = lhs->get();
740  Location loc = op.getLoc();
741  if (isStart) {
742  auto dynShape = {ShapedType::kDynamic};
743  Type etp = cast<ShapedType>(tensor.getType()).getElementType();
744  Type t1 = MemRefType::get(dynShape, etp);
745  Type t2 = MemRefType::get(dynShape, builder.getI1Type());
746  Type t3 = MemRefType::get(dynShape, builder.getIndexType());
747  Type t4 = builder.getIndexType();
748  auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
749  assert(r.getNumResults() == 4);
750  env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
751  r.getResult(3));
752  } else {
753  SmallVector<Value> indices;
754  for (LoopId i = 0; i < curr; i++)
755  indices.push_back(env.emitter().getLoopIV(i));
756  Value values = env.getExpandValues();
757  Value filled = env.getExpandFilled();
758  Value added = env.getExpandAdded();
759  Value count = env.getExpandCount();
760  Value chain = env.getInsertionChain();
761  Value compress = builder.create<CompressOp>(loc, values, filled, added,
762  count, chain, indices);
763  env.updateInsertionChain(compress);
764  env.endExpand();
765  }
766 }
767 
768 /// Returns parallelization strategy. Any implicit loop in the Linalg
769 /// operation that is marked "parallel" is a candidate. Whether it is actually
770 /// converted to a parallel operation depends on the requested strategy.
771 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
772  // Reject parallelization of sparse output.
773  if (env.hasSparseOutput())
774  return false;
775  // Parallel loops on tensor expansion can cause data races.
776  if (env.isExpand())
777  return false;
778  // Inspect strategy.
 779  switch (env.options().parallelizationStrategy) {
 780  case SparseParallelizationStrategy::kNone:
 781  return false;
 782  case SparseParallelizationStrategy::kDenseOuterLoop:
 783  return isOuter && !isSparse;
 784  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
 785  return isOuter;
 786  case SparseParallelizationStrategy::kDenseAnyLoop:
 787  return !isSparse;
 788  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
 789  return true;
 790  }
791  llvm_unreachable("unexpected parallelization strategy");
792 }
793 
 794 /// Whether or not the current loop being generated should be parallelized
 795 /// (if possible) according to the configuration.
796 static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
797  ArrayRef<TensorLevel> tidLvls) {
798  linalg::GenericOp op = env.op();
799  auto iteratorTypes = op.getIteratorTypesArray();
800  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
801  // Queries the LT based on the tensor and loop id, as requested by
802  // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
803  // should be consistent with the LT indexed by <TensorId, Level>.
804  const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
805  return lt.hasSparseSemantic();
806  });
807  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
808 }
809 
810 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
811 /// can either be a for loop or while loop depending on whether there is at most
812 /// one sparse level in the list.
 813 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 814  ArrayRef<TensorLevel> tidLvls,
815  bool tryParallel, bool needsUniv) {
816  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
817  // Construct while-loop with a parameter for each index.
818  return env.emitter().enterCoIterationOverTensorsAtLvls(
819  builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
820  });
821  assert(loop);
822  return loop;
823 }
824 
825 /// Generates a for-loop or a while-loop, depending on whether it implements
826 /// singleton iteration or co-iteration over the given conjunction.
827 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
828  bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
829  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
830  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
831 }
832 
833 /// Generates the induction structure for a while-loop.
834 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
835  bool needsUniv) {
836  Location loc = env.op().getLoc();
837  // Finalize each else branch of all if statements.
838  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
839  while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
840  builder.getInsertionBlock()->getParentOp())) {
841  // Break on IfOp for slicing filtering.
842  if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
843  StringAttr::get(ifOp->getContext(), "slice"))
844  break;
845 
846  unsigned y = 0;
847  SmallVector<Value> yields;
848  if (env.isReduc()) {
849  yields.push_back(env.getReduc());
850  env.updateReduc(ifOp.getResult(y++));
851  if (env.isValidLexInsert()) {
852  yields.push_back(env.getValidLexInsert());
853  env.updateValidLexInsert(ifOp.getResult(y++));
854  }
855  }
856  if (env.isExpand()) {
857  yields.push_back(env.getExpandCount());
858  env.updateExpandCount(ifOp->getResult(y++));
859  }
860  if (env.getInsertionChain()) {
861  yields.push_back(env.getInsertionChain());
862  env.updateInsertionChain(ifOp->getResult(y++));
863  }
864  assert(y == yields.size());
865  builder.create<scf::YieldOp>(loc, yields);
866  builder.setInsertionPointAfter(ifOp);
867  }
868  }
869  // No need to set the insertion point here as LoopEmitter keeps track of the
870  // basic block where scf::Yield should be inserted.
871 }
872 
873 /// Generates a single if-statement within a while-loop.
874 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
875  LatPointId p) {
876  Location loc = env.op().getLoc();
877  SmallVector<Type> types;
878  Value cond;
 879  env.merger().foreachTensorLoopId(
 880  p, /*simple=*/true,
881  [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
882  bool isIdxRed) {
883  if (isIdxRed) {
884  // Since there is no 1:1 mapping from loop to level (multiple loops
885  // are required to resolve one level with non-trivial index
886  // expression), we need to reconstruct the tensor level types if this
887  // loop requires index reduction condition.
888  assert(lvl.has_value() && isUndefLT(lt));
889  auto stt = getSparseTensorType(env.op().getInputs()[tid]);
890  lt = stt.getLvlType(*lvl);
891  }
892  assert(curr == env.merger().loop(b));
893  Value clause;
894  if (lt.hasSparseSemantic()) {
895  assert(lvl.has_value());
896  const Value crd = env.emitter().getCoord(tid, *lvl);
897  const Value lvar = env.getLoopVar(curr);
898  clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
899  crd, lvar);
900  } else {
901  assert(lt.hasDenseSemantic() || isUndefLT(lt));
902  clause = constantI1(builder, loc, true);
903  }
904  cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
905  });
906  if (env.isReduc()) {
907  types.push_back(env.getReduc().getType());
908  if (env.isValidLexInsert())
909  types.push_back(env.getValidLexInsert().getType());
910  }
911  if (env.isExpand())
912  types.push_back(builder.getIndexType());
913  if (env.getInsertionChain())
914  types.push_back(env.getInsertionChain().getType());
915  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
916  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
917  return ifOp;
918 }
919 
920 /// Generates end of true branch of if-statement within a while-loop.
921 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
922  Value redInput, Value cntInput, Value insInput,
923  Value validIns) {
924  SmallVector<Value> operands;
925  if (env.isReduc()) {
926  operands.push_back(env.getReduc());
927  env.updateReduc(redInput);
928  if (env.isValidLexInsert()) {
 929  // Any overlapping indices during a reduction create a valid lex insert.
930  operands.push_back(constantI1(builder, env.op().getLoc(), true));
931  env.updateValidLexInsert(validIns);
932  }
933  }
934  if (env.isExpand()) {
935  operands.push_back(env.getExpandCount());
936  env.updateExpandCount(cntInput);
937  }
938  if (env.getInsertionChain()) {
939  operands.push_back(env.getInsertionChain());
940  env.updateInsertionChain(insInput);
941  }
942  if (!operands.empty())
943  builder.create<scf::YieldOp>(env.op().getLoc(), operands);
944  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // Sparsifier synthesis methods (loop sequence).
949 //===----------------------------------------------------------------------===//
950 
 951 static bool getAllTidLvlsInLatPoints(
 952  CodegenEnv &env, LatPointId li, LoopId curr,
953  llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
954  const BitVector &simple = env.lat(li).simple;
955  const TensorId outTid = env.merger().getOutTensorID();
956  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
957 
958  unsigned numloopCond = 0;
959  bool hasNonUnique = false;
 960  env.merger().foreachTensorLoopId(
 961  li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
962  LevelType lt, bool isIdxReduc) {
963  if (simple[b]) {
964  if (isIdxReduc) {
965  callback(env.makeTensorLevel(tid, *lvl), nullptr);
966  numloopCond++;
967  return;
968  }
969  if (isUndefLT(lt)) {
 970  // An undefined lt in the lattices means we probably want to
 971  // generate a dense loop according to the synthetic tensor (for
 972  // invariants and the sparse output tensor).
973  if (env.merger().getSynTensorID() == tid) {
974  // Coiterating with an invariant
975  // e.g., out = prod(in[i][j] op invariant);
976  // or a broadcast
977  // e.g., out[i][j] = in[i] (j is undef for input)
978  //
979  // The level of the synthetic tensor is the current loop depth;
 980  // the rank of the synthetic tensor equals the number of loops.
981  assert(curr == env.getCurrentDepth());
982  lvl = curr;
983  } else if (!lvl) {
984  // Skips invalid lvl (e.g., when this is a zero ranked tensor).
985  return;
986  }
987  }
988  hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
989  callback(env.makeTensorLevel(tid, *lvl), nullptr);
990  numloopCond++;
991  } else if (lt.hasDenseSemantic() || isIdxReduc) {
992  callback(env.makeTensorLevel(tid, *lvl), nullptr);
993  } else {
994  assert(isUndefLT(lt));
995  linalg::GenericOp op = env.op();
996  if (tid >= op.getNumDpsInputs())
997  // We only handle affine expression on input tensors (for now).
998  return;
999  OpOperand *operand = &op->getOpOperand(tid);
1000  const auto stt = getSparseTensorType(operand->get());
 1001  // Non-annotated dense tensors require no special handling.
1002  if (!stt.hasEncoding())
1003  return;
1004 
1005  ArrayRef<AffineExpr> affines =
1006  op.getMatchingIndexingMap(operand).getResults();
1007  const Level lvlRank = stt.getLvlRank();
1008  assert(affines.size() == static_cast<size_t>(lvlRank));
1009  for (Level l = 0; l < lvlRank; l++) {
1010  AffineExpr exp = affines[l];
1011  // Skip simple affine expression and non-dense levels (which
1012  // have their own filter loop).
1013  LevelType lt = stt.getLvlType(l);
1014  if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1015  continue;
1016 
 1017  // Constant affine expressions are handled in genLoop.
1018  if (!isa<AffineConstantExpr>(exp)) {
1019  bool isCurrentLoop = false;
1020  assert(curr == env.getCurrentDepth());
1021  if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1022  isCurrentLoop) {
 1023  // If the compound affine expression is invariant and we are right
 1024  // at the level, we need to generate the address according to the
 1025  // affine expression. This is also the best place to do it, to
 1026  // avoid putting it inside inner loops.
1027  callback(env.makeTensorLevel(tid, l), exp);
1028  }
1029  }
1030  }
1031  }
1032  });
1033 
1034  if (isDenseLT(env.lt(outTid, curr))) {
1035  auto stt = getSparseTensorType(env.op().getOutputs().front());
1036  // Note that we generate dense indices of the output tensor unconditionally,
1037  // since they may not appear in the lattice, but may be needed for
1038  // linearized env.
1039  // TODO: we should avoid introducing corner cases for all-dense sparse
1040  // tensors.
1041  if (stt.hasEncoding() && stt.isAllDense())
1042  callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1043  }
1044 
1045  if (numloopCond == 0) {
 1046  // Corner case where the loop bound is defined by an *unused* operand; in
 1047  // this case, we just generate a dense "fake" loop by iterating over the
 1048  // synthetic tensor.
1049  callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1050  numloopCond++;
1051  }
 1052  // If we only need one loop condition and the condition is not imposed on a
 1053  // non-unique level, the loop can be generated by a for loop.
1054  return numloopCond == 1 && !hasNonUnique;
1055 }
1056 
1057 /// Starts a loop sequence at given level. Returns true if
1058 /// the universal loop index must be maintained at this level.
1059 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1060  LoopId curr, LatSetId lts) {
1061  assert(!env.getLoopVar(curr));
1062  // Emit invariants at this loop sequence level.
1063  genInvariants(env, builder, exp, curr, /*isStart=*/true);
1064  // Emit access pattern expansion for sparse tensor output.
1065  genExpand(env, builder, curr, /*isStart=*/true);
1066  // Emit further initialization at this loop sequence level.
1067  const LatPointId l0 = env.set(lts)[0];
1068 
1069  SmallVector<TensorLevel> tidLvls;
1070  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
 1071  // TODO: remove this! The same tensor level might be added multiple
 1072  // times due to the special handling for the all-dense "sparse" output
 1073  // tensor (see L1038).
1074  if (llvm::find(tidLvls, tl) != tidLvls.end())
1075  return;
1076  tidLvls.emplace_back(tl);
1077  });
1078 
1079  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1080 
1081  // Maintain the universal index only if it is actually
1082  // consumed by a subsequent lattice point.
1083  for (const LatPointId li : env.set(lts).drop_front())
1084  if (!env.merger().hasAnySparse(env.lat(li).simple))
1085  return true;
1086 
1087  return false;
1088 }
1089 
1090 // Generates dense affine address for encoding.
 1091 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
 1092  OpBuilder &builder, TensorId tid,
1093  Level startLvl) {
1094  // TODO: Handle affine expression on output tensor.
1095  linalg::GenericOp op = env.op();
1096  assert(tid < op.getNumDpsInputs());
1097  OpOperand *input = op.getDpsInputOperands()[tid];
1098  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1099  const auto enc = getSparseTensorEncoding(input->get().getType());
1100  if (enc) {
1101  const Location loc = op.getLoc();
1102  const TensorId tid = env.makeTensorId(input->getOperandNumber());
1103  const Level lvlRank = enc.getLvlRank();
1104  assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1105  for (Level l = startLvl; l < lvlRank; l++) {
1106  AffineExpr lvlExpr = lvlExprs[l];
1107  if (enc.getLvlType(l).hasDenseSemantic() &&
1108  isa<AffineConstantExpr>(lvlExpr))
 1109  env.emitter().locateLvlAtAffineAddress(
 1110  builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1111  else
1112  return; // break on first non-dense non-constant level
1113  }
1114  }
1115 }
1116 
1117 // We can generate address for constant affine expression before any loops
1118 // starting from the first level as they do not depend on anything.
1119 // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1120 // levels can be determined before loops.
 1121 static void genInitConstantDenseAddress(CodegenEnv &env,
 1122  RewriterBase &rewriter) {
1123  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1124  genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1125 }
1126 
1127 /// Returns true if the lattice bit can be iterated by a for loop.
 1128 static bool translateBitsToTidLvlPairs(
 1129  CodegenEnv &env, LatPointId li, LoopId curr,
 1130  SmallVectorImpl<TensorLevel> &tidLvls,
 1131  SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1132  return getAllTidLvlsInLatPoints(env, li, curr,
1133  [&](TensorLevel tl, AffineExpr exp) {
1134  if (exp)
1135  affineTidLvls.emplace_back(tl, exp);
1136  else
1137  tidLvls.emplace_back(tl);
1138  });
1139 }
1140 
1141 /// Starts a single loop in current sequence.
1142 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1143  OpBuilder &builder, LoopId curr,
1144  LatPointId li, bool needsUniv) {
1145  // The set of tensors + lvls to generate loops on
1146  SmallVector<TensorLevel> tidLvls;
1147 
 1148  // The set of dense tensors with non-trivial affine expressions that just
 1149  // became invariant, whose addresses are generated at the current level.
 1150  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
 1151  bool isSingleCond =
1152  translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1153 
1154  // Emit the for/while-loop control.
1155  Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
1156  Location loc = env.op().getLoc();
1157  for (auto [tidLvl, exp] : affineTidLvls) {
1158  env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1159  }
1160 
1161  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
 1162  // affine}Tids/Lvls. The addresses of the upcoming levels which depend
 1163  // on constant affine expressions may now be determined.
1164  auto allTidLvls =
1165  llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1166  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1167  if (tid != env.merger().getOutTensorID() &&
1168  tid != env.merger().getSynTensorID())
1169  genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1170  }
1171 
1172  return std::make_pair(loop, isSingleCond);
1173 }
1174 
1175 /// Ends a single loop in current sequence. Returns new values for needsUniv.
1176 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1177  LatPointId li, bool needsUniv, bool isSingleCond) {
1178  // Either a for-loop or a while-loop that iterates over a slice.
1179  if (isSingleCond) {
1180  // Any iteration creates a valid lex insert.
1181  if (env.isReduc() && env.isValidLexInsert())
1182  env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1183  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1184  // End a while-loop.
1185  finalizeWhileOp(env, rewriter, needsUniv);
1186  } else {
1187  needsUniv = false;
1188  }
1189  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1190  env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1191  return std::nullopt;
1192  });
1193  return needsUniv;
1194 }
1195 
1196 /// Ends a loop sequence at given level.
1197 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1198  unsigned at) {
1199  assert(!env.getLoopVar(at));
1200  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1201  // Unmark bookkeeping of invariants and loop index.
1202  genInvariants(env, builder, exp, at, /*isStart=*/false);
1203  // Finalize access pattern expansion for sparse tensor output.
1204  genExpand(env, builder, at, /*isStart=*/false);
1205 }
1206 
1207 /// Recursively generates code while computing iteration lattices in order
1208 /// to manage the complexity of implementing co-iteration over unions
 1209 /// and intersections of sparse iteration spaces.
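/// Each recursive invocation emits one loop of the final loop nest; once the
/// recursion depth reaches the number of loops, the remaining scalar
/// (sub)expression is generated and stored into the output.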
1210 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1211  LoopId curr) {
1212  assert(curr == env.getCurrentDepth());
1213 
1214  // At each leaf, assign remaining tensor (sub)expression to output tensor.
1215  if (curr == env.getLoopNum()) {
1216  Value rhs = genExp(env, rewriter, exp);
1217  genTensorStore(env, rewriter, exp, rhs);
1218  return;
1219  }
1220 
1221  // Construct iteration lattices for current loop index.
1222  const LatSetId lts =
1223  env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1224 
1225  // Start a loop sequence.
1226  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1227 
1228  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1229  // We cannot change this to `for (const LatPointId li : env.set(lts))`
1230  // because the loop body causes data-movement which invalidates
1231  // the iterator.
1232  const unsigned lsize = env.set(lts).size();
1233  for (unsigned i = 0; i < lsize; i++) {
1234  const LatPointId li = env.set(lts)[i];
1235  // Start a loop.
1236  auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);
1237 
1238  // Visit all lattices points with Li >= Lj to generate the
1239  // loop-body, possibly with if statements for coiteration.
1240  Value redInput = env.getReduc();
1241  Value cntInput = env.getExpandCount();
1242  Value insInput = env.getInsertionChain();
1243  Value validIns = env.getValidLexInsert();
1244  // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1245  // because the loop body causes data-movement which invalidates the
1246  // iterator.
1247  for (unsigned j = 0; j < lsize; j++) {
1248  const LatPointId lj = env.set(lts)[j];
1249  const ExprId ej = env.lat(lj).exp;
1250  if (li == lj || env.merger().latGT(li, lj)) {
1251  // Recurse into body of each branch.
1252  if (!isSingleCond) {
1253  scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1254  genStmt(env, rewriter, ej, curr + 1);
1255  endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1256  } else {
1257  genStmt(env, rewriter, ej, curr + 1);
1258  }
1259  }
1260  }
1261 
1262  // End a loop.
1263  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1264  }
1265 
1266  // End a loop sequence.
1267  endLoopSeq(env, rewriter, exp, curr);
1268  assert(curr == env.getCurrentDepth());
1269 }
1270 
1271 /// Converts the result computed by the sparse kernel into the required form.
1272 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1273  linalg::GenericOp op = env.op();
1274  OpOperand *lhs = op.getDpsInitOperand(0);
1275  Value tensor = lhs->get();
1276  Type resType = tensor.getType();
1277  if (getSparseTensorEncoding(resType)) {
1278  // The sparse tensor rematerializes from the original sparse tensor's
1279  // underlying sparse storage format. For an insertion chain, the
1280  // tensor materializes from the chain with 'hasInserts' enabled.
1281  bool hasInserts = false;
1282  if (Value chain = env.getInsertionChain()) {
1283  hasInserts = true;
1284  tensor = chain;
1285  }
1286  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1287  } else {
 1288  // To rematerialize a non-annotated tensor, simply load it
1289  // from the bufferized value.
1290  Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1291  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1292  }
1293 }
1294 
1295 //===----------------------------------------------------------------------===//
1296 // Sparsifier rewriting methods.
1297 //===----------------------------------------------------------------------===//
1298 
1299 namespace {
1300 
 1301 /// Sparse rewriting rule for a generic Linalg operation.
1302 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1303 public:
1304  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1305  : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1306 
1307  LogicalResult matchAndRewrite(linalg::GenericOp op,
1308  PatternRewriter &rewriter) const override {
1309  // Only accept single output operations with pure tensor semantics.
1310  if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1311  return failure();
1312 
1313  // Only accept trivial affine indices.
 1314  if (hasNonTrivialAffineOnSparseOut(op))
 1315  return failure();
1316 
1317  // Only accept scheduled loops.
1318  if (!op->hasAttr("sorted")) {
1319  return rewriter.notifyMatchFailure(
1320  op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1321  "before sparsification.");
1322  }
1323 
1324  // Must have been demapped as well if the generic op is sorted.
 1325  assert(!hasAnyNonIdentityOperandsOrResults(op));
 1326 
1327  // Sets up a code generation environment.
1328  const unsigned numTensors = op->getNumOperands();
1329  const unsigned numLoops = op.getNumLoops();
1330  bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
 1331  // If we have an indexing map like (d0) -> (0, d0), there might be more
 1332  // levels than loops because of the constant index, which means we cannot
 1333  // use numLoops as the upper bound for the ranks of all tensors.
 1334  // TODO: Constant indices are currently not supported on sparse tensors,
 1335  // but are allowed in non-annotated dense tensors. Supporting them would
 1336  // also be required for sparse tensor slice rank reduction.
1337  Level maxLvlRank = 0;
1338  for (auto operand : op.getOperands()) {
1339  if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1340  maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1341  }
1342  }
1343 
1344  // Detects sparse annotations and translates the per-level sparsity
1345  // information for all tensors to loop indices in the kernel.
1346  CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1347  if (!findSparseAnnotations(env, needIdxRed))
1348  return failure();
1349 
1350  // Only standard reduction operations (add, sub, or, xor) that can be
1351  // sparsified by merely reducing the stored values are admissible. More
1352  // elaborate reduction operations (such as mul, and, min, max) would need
1353  // to know whether implicit zeros occur as well. They can still be
1354  // implemented with a custom reduction operation, accepted here as well.
1355  if (op.getNumReductionLoops() > 0) {
1356  Operation *yield = op.getRegion().front().getTerminator();
1357  assert(isa<linalg::YieldOp>(yield));
1358  Operation *redop = yield->getOperand(0).getDefiningOp();
1359  if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1360  !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1361  !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1362  !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1363  !isa<ReduceOp>(redop)) {
1364  return failure();
1365  }
1366  }
1367 
1368  // Constructs the tensor expressions tree from `op`, returns failure if the
1369  // tree can not be built or the tensor expression is inadmissible.
1370  if (failed(env.initTensorExp()))
1371  return failure();
1372 
1373  // Recursively generates code if admissible.
1374  env.startEmit(options.sparseEmitStrategy);
1375  genBuffers(env, rewriter);
 1376  // TODO: Constant affine expressions should be handled differently when
 1377  // using slice-based codegen; it does not matter now because we already
 1378  // reject constant expressions at an earlier stage.
1379  genInitConstantDenseAddress(env, rewriter);
1380  genStmt(env, rewriter, env.getExprId(), 0);
1381  genResult(env, rewriter);
1382  return success();
1383  }
1384 
1385 private:
1386  /// Options to control sparse code generation.
 1387  SparsificationOptions options;
 1388 };
1389 
1390 } // namespace
1391 
1392 /// Populates the given patterns list with rewriting rules required for
1393 /// the sparsification of linear algebra operations.
 1394 void mlir::populateSparsificationPatterns(
 1395  RewritePatternSet &patterns, const SparsificationOptions &options) {
1396  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1397 }
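// A minimal sketch (illustrative, not part of this file) of how these
// patterns are typically driven from a pass:
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));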
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, Value e)
Semi-ring branches are simply inlined by the sparsifier.
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
static void genResult(CodegenEnv &env, RewriterBase &rewriter)
Converts the result computed by the sparse kernel into the required form.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr, ArrayRef< TensorLevel > tidLvls)
Whether or not the current loop being generated should be parallized (if possible) according to the c...
static bool translateBitsToTidLvlPairs(CodegenEnv &env, LatPointId li, LoopId curr, SmallVectorImpl< TensorLevel > &tidLvls, SmallVectorImpl< std::pair< TensorLevel, AffineExpr >> &affineTidLvls)
Returns true if the lattice bit can be iterated by a for loop.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e)
Recursively generates tensor expression.
static void genInitConstantDenseAddress(CodegenEnv &env, RewriterBase &rewriter)
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp)
Generates a load on a dense or sparse tensor.
static Value genInvariantValue(CodegenEnv &env, ExprId exp)
Generates an invariant value.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, LatPointId li, bool needsUniv, bool isSingleCond)
Ends a single loop in current sequence. Returns new values for needsUniv.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, bool needsUniv)
Generates the induction structure for a while-loop.
static Value genIndex(CodegenEnv &env, OpOperand *t)
Generates index for load/store on sparse tensor.
Base type for affine expression.
Definition: AffineExpr.h:69
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:383
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
The code generation environment class aggregates a number of data structures that are needed during t...
Definition: CodegenEnv.h:35
void startReduc(ExprId exp, Value val)
Definition: CodegenEnv.cpp:228
void updateValidLexInsert(Value val)
Definition: CodegenEnv.cpp:256
const SparsificationOptions & options() const
Definition: CodegenEnv.h:51
std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)
Generates loop boundary statements (entering/exiting loops).
Definition: CodegenEnv.cpp:103
ArrayRef< LatPointId > set(LatSetId s) const
Definition: CodegenEnv.h:79
bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const
Definition: CodegenEnv.cpp:200
unsigned getCurrentDepth() const
Definition: CodegenEnv.h:111
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tl) const
Definition: CodegenEnv.h:103
TensorLevel makeTensorLevel(TensorId t, Level l) const
Definition: CodegenEnv.h:91
const LatPoint & lat(LatPointId l) const
Definition: CodegenEnv.h:78
constexpr TensorId makeTensorId(unsigned t) const
Definition: CodegenEnv.h:68
void startExpand(Value values, Value filled, Value added, Value count)
Definition: CodegenEnv.cpp:205
unsigned getLoopNum() const
Definition: CodegenEnv.h:85
void updateInsertionChain(Value chain)
Definition: CodegenEnv.cpp:195
void startCustomReduc(ExprId exp)
Definition: CodegenEnv.cpp:266
linalg::GenericOp op() const
Definition: CodegenEnv.h:50
Value getLoopVar(LoopId i) const
Returns the induction-variable for the given loop.
Definition: CodegenEnv.cpp:187
void startEmit(SparseEmitStrategy emitStrategy)
Definition: CodegenEnv.cpp:62
auto unpackTensorLevelRange(ContainerTy &&c) const
Definition: CodegenEnv.h:107
const TensorExp & exp(ExprId e) const
Definition: CodegenEnv.h:77
void updateExpandCount(Value count)
Definition: CodegenEnv.cpp:214
bool isSparseOutput(OpOperand *o) const
Definition: CodegenEnv.h:129
void startValidLexInsert(Value val)
Definition: CodegenEnv.cpp:251
constexpr LoopId makeLoopId(unsigned i) const
Definition: CodegenEnv.h:71
LevelType lt(TensorId t, LoopId i) const
Definition: CodegenEnv.h:80
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
Definition: LoopEmitter.h:234
void locateLvlAtAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr)
Emits the address for a dense level based on the value evaluated by the provided affine expression.
const std::vector< Value > & getValBuffer() const
Definition: LoopEmitter.h:232
void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls)
Enters a new loop sequence, the loops within the same sequence starts from the break points of previo...
Value genAffine(OpBuilder &builder, Location loc, AffineExpr a)
Generates code to compute an affine expression whose variables are LoopIds (i.e., a....
Value getLoopIV(LoopId n) const
Gets loop induction variable for the given loop.
Definition: LoopEmitter.h:171
SmallVector< Value > getValPosits(TensorId tid) const
Getters.
Definition: LoopEmitter.h:223
auto getLoopIVsRange() const
Get the range of values for all induction variables.
Definition: LoopEmitter.h:157
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
void exitCurrentLoopSeq(OpBuilder &builder, Location loc)
Exits the current loop sequence, this will reset universal index to 0.
Value getCoord(TensorId tid, Level lvl) const
Definition: LoopEmitter.h:229
A class to handle all iteration lattice operations.
Definition: Merger.h:224
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the the tth tensor on ith loop.
Definition: Merger.h:415
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:935
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition: Merger.h:248
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
Definition: Merger.h:425
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
Definition: Merger.h:441
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:430
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
Definition: Merger.h:471
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
Definition: Merger.cpp:669
void clearExprValue(ExprId e)
Clears the value associated with the expression.
Definition: Merger.h:566
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
Definition: Merger.h:366
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
Definition: Merger.cpp:502
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
Definition: Merger.h:347
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
Definition: Merger.h:362
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Definition: Merger.h:398
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:1537
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
Definition: Merger.h:558
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has dependent slice.
Definition: Merger.h:480
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllDense() const
Returns true for tensors where every level is dense.
Level getLvlRank() const
Returns the level-rank.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition: Merger.h:30
bool isUniqueLT(LevelType lt)
Definition: Enums.h:424
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:312
unsigned LatSetId
LatSet identifiers.
Definition: Merger.h:57
unsigned TensorLevel
Definition: LoopEmitter.h:26
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:35
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition: Merger.h:44
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
unsigned LoopId
Loop identifiers.
Definition: Merger.h:38
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:359
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool isUndefLT(LevelType lt)
Definition: Enums.h:408
bool isDenseLT(LevelType lt)
Definition: Enums.h:409
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
unsigned ExprId
TensorExp identifiers.
Definition: Merger.h:48
unsigned LatPointId
LatPoint identifiers.
Definition: Merger.h:52
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition: Merger.h:35
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Options for the Sparsification pass.
Definition: Passes.h:91
SparseParallelizationStrategy parallelizationStrategy
Definition: Passes.h:104
ExprId exp
Identifier of the tensor expression.
Definition: Merger.h:217
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
Definition: Merger.h:214
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
constexpr bool hasSparseSemantic() const
Check if the LevelType is considered to be sparse.
Definition: Enums.h:337
constexpr bool hasDenseSemantic() const
Check if the LevelType is considered to be dense-like.
Definition: Enums.h:343
Tensor expression. Represents an MLIR expression in tensor index notation.
Definition: Merger.h:67
LoopId loop
kLoopVar expressions simply have a loop identifier.
Definition: Merger.h:96
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition: Merger.h:105
Children children
All other expressions hold the ExprIds of their children.
Definition: Merger.h:99
TensorId tensor
kTensor expressions simply have a tensor identifier.
Definition: Merger.h:93
Kind kind
Tensor expression kind.
Definition: Merger.h:89
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.