Sparsification.cpp
1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenEnv.h"
14 #include "Utils/CodegenUtils.h"
15 #include "Utils/LoopEmitter.h"
16 
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TensorEncoding.h"
36 #include "llvm/ADT/SmallBitVector.h"
37 
38 #include <optional>
39 
40 using namespace mlir;
41 using namespace mlir::sparse_tensor;
42 
43 //===----------------------------------------------------------------------===//
44 // Sparsifier analysis methods.
45 //===----------------------------------------------------------------------===//
46 
47 /// Returns true iff the affine expression is invariant. Sets the parameter
48 /// `isCurrentLoop` when the expression just became invariant.
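/// For example, with curr = 1 the expression d0 + 2 is invariant and
/// `isCurrentLoop` is set (d0 just became available); with curr = 2 the same
/// expression is still invariant, but `isCurrentLoop` is left untouched.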
49 static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
50  switch (a.getKind()) {
51  case AffineExprKind::DimId: {
52  const LoopId i = cast<AffineDimExpr>(a).getPosition();
53  if (i + 1 == curr) {
54  isCurrentLoop = true;
55  return true; // becomes invariant at current loop
56  }
57  return i < curr; // invariant when already generated
58  }
59  case AffineExprKind::Add:
60  case AffineExprKind::Mul: {
61  auto binOp = cast<AffineBinaryOpExpr>(a);
62  return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
63  isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
64  }
65  default: {
66  assert(isa<AffineConstantExpr>(a));
67  return true;
68  }
69  }
70 }
71 
72 /// Helper method to inspect affine expressions. Rejects cases where the
73 /// same index is used more than once. Also rejects compound affine
74 /// expressions in sparse dimensions.
75 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
76  LevelType lt, bool setLvlFormat = true) {
77  switch (a.getKind()) {
78  case AffineExprKind::DimId: {
79  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
80  if (!isUndefLT(merger.getLvlType(tid, idx)))
81  return false; // used more than once
82  if (setLvlFormat)
83  merger.setLevelAndType(tid, idx, lvl, lt);
84  return true;
85  }
86  case AffineExprKind::Add:
87  case AffineExprKind::Mul:
88  case AffineExprKind::Constant: {
89  assert(lt.hasDenseSemantic());
90  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
91  // We do not set the dim level format for an affine expression like d0 + d1
92  // on either loop index d0 or d1. We continue the recursion merely to check
93  // whether the current affine expression is admissible or not.
94  return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
95  findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
96  }
97  // Falls through when it is a constant affine expression.
98  return true;
99  }
100  default:
101  return false;
102  }
103 }
104 
105 /// Helper method to inspect affine expressions for index variable reduction
106 /// based codegen. It finds the dependent index set for all tensor levels in the
107 /// current expression we are generating.
108 ///
109 /// For example, when handling A[i+j][j+k], we build the two way mapping in
110 /// merger between (tensor, level) pairs and their dependent index variable set:
111 /// A_0 <=> [i, j] and A_1 <=> [j, k]
112 ///
113 /// It rejects the following cases (returns false):
114 /// 1st, when the same index is used more than once, e.g., A[i+j][i];
115 /// 2nd, when multiplication is used in a non-trivial index expression;
116 /// 3rd, when a constant operand is used in a non-trivial index expression.
117 ///
118 /// TODO: constant should be easy to handle.
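/// For example, for B[2 * i + j] the merger records that the level depends on
/// loops {i, j}, with coefficient 2 attached to i and coefficient 1 to j.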
119 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
120  AffineExpr a, LevelType lt, bool isSubExp = false,
121  int64_t coefficient = 1) {
122  switch (a.getKind()) {
123  case AffineExprKind::DimId: {
124  // Only allow positive coefficients on AffineDimExpr.
125  if (coefficient <= 0)
126  return false;
127 
128  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
129  if (!isUndefLT(merger.getLvlType(tensor, idx)))
130  return false; // used more than once, e.g., A[i][i]
131 
132  // TODO: Generalize the following two cases. A[i] (with a trivial index
133  // expression) can be treated as a special affine index expression. We do
134  // not necessarily need to differentiate them.
135  if (!isSubExp) {
136  assert(coefficient == 1);
137  merger.setLevelAndType(tensor, idx, lvl, lt);
138  }
139 
140  if (isSubExp) {
141  // The current loop appears in more than one affine expression on the
142  // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
143  // is used twice.
144  if (merger.hasDependentLvl(idx, tensor)) {
145  // TODO: This can be supported by coiterating slices if the loop index
146  // appears in affine indices of different tensors, or by taking slices
147  // along multiple dimensions when it is on the same tensor.
148  // E.g.,
149  // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
150  // d0_1 = getNextSliceOffset t0 along lvl0
151  // d0_2 = getNextSliceOffset t1 along lvl0
152  // if d0_1 == d0_2 then d0 = d0_1 = d0_2
153  // else increase min(d0_1, d0_2).
154  return false;
155  }
156  merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
157  }
158  return true;
159  }
160  case AffineExprKind::Constant:
161  case AffineExprKind::Mul: {
162  // TODO: Support an index expression like `2 * d0` on its own; currently we
163  // only support it inside compound expressions like `2 * d0 + d1`.
164  if (!isSubExp)
165  return false;
166 
167  // TODO: Support Constant AffineExp for slice-based codegen
168  if (isa<AffineConstantExpr>(a))
169  llvm_unreachable("Not yet implemented");
170 
171  auto binOp = cast<AffineBinaryOpExpr>(a);
172  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
173  if (isa<AffineConstantExpr>(rhs))
174  std::swap(lhs, rhs);
175  // Must be in the form `constant * d`.
176  assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
177  int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
178  return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
179  }
180  case AffineExprKind::Add: {
181  auto binOp = cast<AffineBinaryOpExpr>(a);
182  return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
183  findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
184  }
185  default:
186  return false;
187  }
188 }
189 
190 /// Gets the total number of compound affine expressions in the
191 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
192 ///
193 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
194 ///
195 /// Returns 1 (because the first level is compressed and its corresponding
196 /// indexing-expression is `d0 + d1`)
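/// In contrast, for map = (d0, d1) => (d0 : compressed, d0 + d1 : dense) it
/// returns 0, because the compound expression sits on a dense level, where
/// random access makes it admissible without index reduction.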
197 static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
198  Value tensor) {
199  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
200  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
201  // However, we don't need to handle `StorageSpecifierType`, so we
202  // can use `SparseTensorType` once we guard against non-tensors.
203  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
204  if (!rtp)
205  return 0;
206  const SparseTensorType stt(rtp);
207 
208  const Level lvlRank = stt.getLvlRank();
209  const auto exprs = map.getResults();
210  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
211  "AffineMap does not have dimension-rank many results");
212  unsigned num = 0;
213  for (Level l = 0; l < lvlRank; l++) {
214  if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
215  num++;
216  }
217  return num;
218 }
219 
220 /// Gets the total number of sparse levels with compound affine
221 /// expressions, summed over all operands of the `GenericOp`.
222 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
223  unsigned num = 0;
224  for (OpOperand &t : op->getOpOperands())
225  num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
226  t.get());
227  return num;
228 }
229 
230 // Returns true iff output has nontrivial affine indices.
231 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
232  OpOperand *out = op.getDpsInitOperand(0);
233  if (getSparseTensorType(out->get()).isAllDense())
234  return false;
235  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
236  out->get());
237 }
238 
239 /// Helper method to inspect sparse encodings in the tensor types.
240 /// Fills the per-dimension sparsity information for all tensors.
241 /// Returns true if the sparse annotations and affine subscript
242 /// expressions of all tensors are admissible. Returns false if
243 /// no annotations are found or inadmissible constructs occur.
244 /// We currently support two different ways to handle non-trivial index
245 /// expressions on sparse tensors, and they accept different affine expressions.
246 /// When using the dependent-index-reduction-based approach, it currently only
247 /// supports affine addition index expressions.
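/// For example, a kernel C[i][j] = A[i+j][k] * B[k][j] with A sparse needs the
/// index-reduction-based path, since level 0 of A is addressed by the compound
/// expression i+j; the same kernel with A dense can rely on random access.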
248 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
249  bool annotated = false;
250  for (OpOperand &t : env.op()->getOpOperands()) {
251  const TensorId tid = env.makeTensorId(t.getOperandNumber());
252  const auto map = env.op().getMatchingIndexingMap(&t);
253  const auto enc = getSparseTensorEncoding(t.get().getType());
254  if (enc)
255  annotated = true;
256  const Level lvlRank = map.getNumResults();
257  assert(!enc || lvlRank == enc.getLvlRank());
258  assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
259  // We only need to do index reduction if there is at least one
260  // non-trivial index expression on sparse levels. If all non-trivial
261  // index expressions are on dense levels, we can efficiently rely on
262  // random access to locate the element.
263  bool needIdxReduc =
264  enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
265  // If the current tensor being inspected requires affine indices, it needs
266  // to be sliced.
267  for (Level l = 0; l < lvlRank; l++) {
268  const AffineExpr a = map.getResult(l);
269  const LevelType lt = enc.getLvlType(l);
270  if (idxReducBased && needIdxReduc) {
271  if (!findDepIdxSet(env.merger(), tid, l, a, lt))
272  return false; // inadmissible affine expression
273  } else {
274  if (!findAffine(env.merger(), tid, l, a, lt))
275  return false; // inadmissible affine expression
276  }
277  }
278  }
279  return annotated;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Sparsifier synthesis methods (statements and expressions).
284 //===----------------------------------------------------------------------===//
285 
286 /// Local bufferization of all dense and sparse data structures.
287 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
288  linalg::GenericOp op = env.op();
289  Location loc = op.getLoc();
290  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
291 
292  SmallVector<Range, 4> loopRange =
293  llvm::cast<linalg::LinalgOp>(op.getOperation())
294  .createLoopRanges(builder, loc);
295 
296  env.emitter().initializeLoopEmit(
297  builder, loc,
298  /// Generates buffer for the output tensor.
299  /// Note that all sparse kernels assume that when all elements are written
300  /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
301  /// to all zeroes and only nonzero values are computed and written out.
302  /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
303  /// for the updates and no assumption on the original contents of the
304  /// output buffer is necessary.
305  [&op](OpBuilder &builder, Location loc, Value memref,
306  Value tensor) -> Value {
307  // Must not be a sparse tensor.
308  assert(!getSparseTensorEncoding(tensor.getType()));
309  // Two output tensor references should point to the same object.
310  OpOperand *lhs = op.getDpsInitOperand(0);
311  assert(lhs->get() == tensor);
312  // An output tensor can simply materialize from the buffer of the tensor
313  // that appears in the outs() clause. For updates, this has the
314  // advantage that only the nonzero values are involved in the
315  // computation, keeping the operation O(nnz). In all other cases, we are
316  // forced to zero out the buffer to enforce the assumption above, which
317  // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
318  // O(nnz) for matrices).
319  // TODO: use better analysis to avoid zeroing out the buffer?
320  bool isInit = op.isInitTensor(lhs);
321  Value init = memref;
322  if (!isInit) {
323  Value zero = constantZero(builder, loc,
324  getElementTypeOrSelf(tensor.getType()));
325  builder.create<linalg::FillOp>(loc, ValueRange{zero},
326  ValueRange{init});
327  }
328  return init;
329  },
330  [&loopRange](OpBuilder &b, Location loc, Level l) {
331  assert(l < loopRange.size());
332  return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
333  });
334 }
335 
336 /// Generates index for load/store on sparse tensor.
337 static Value genIndex(CodegenEnv &env, OpOperand *t) {
338  const auto map = env.op().getMatchingIndexingMap(t);
339  const auto stt = getSparseTensorType(t->get());
340  const Level lvlRank = stt.getLvlRank();
341  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
342  const AffineExpr a = map.getResult(lvlRank - 1);
343  assert(a.getKind() == AffineExprKind::DimId);
344  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
345  return env.getLoopVar(idx);
346 }
347 
348 /// Generates subscript for load/store on a dense or sparse tensor.
349 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
350  SmallVectorImpl<Value> &args) {
351  const Location loc = env.op().getLoc();
352  const TensorId tid = env.makeTensorId(t->getOperandNumber());
353  const auto map = env.op().getMatchingIndexingMap(t);
354  const auto stt = getSparseTensorType(t->get());
355  if (stt.hasEncoding()) {
356  // For sparse tensors we only push the last-level's position onto `args`.
357  const auto pos = env.emitter().getValPosits(tid);
358  assert(!pos.empty());
359  args.append(pos);
360  // Simply returns the tensor to extract value using iterators.
362  return t->get();
363  } else {
364  // For dense tensors we push all level's coordinates onto `args`.
365  const Level lvlRank = stt.getLvlRank();
366  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
367  for (Level l = 0; l < lvlRank; l++) {
368  const auto lvlExpr = map.getResult(l);
369  const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
370  args.push_back(lvlCrd);
371  }
372  }
373  return env.emitter().getValBuffer()[tid];
374 }
375 
376 /// Generates insertion code to implement dynamic tensor load.
377 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
378  OpOperand *t) {
379  linalg::GenericOp op = env.op();
380  Location loc = op.getLoc();
381  // Direct lexicographic coordinate order, tensor loads as zero.
382  if (!env.isExpand()) {
383  Type tp = getElementTypeOrSelf(t->get().getType());
384  return constantZero(builder, loc, tp);
385  }
386  // Load from expanded access pattern.
387  Value index = genIndex(env, t);
388  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
389 }
390 
391 /// Generates insertion code to implement dynamic tensor load for reduction.
392 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
393  OpOperand *t) {
394  linalg::GenericOp op = env.op();
395  Location loc = op.getLoc();
396  Value identity = env.getCustomRedId();
397  // Direct lexicographic coordinate order, tensor loads as identity.
398  if (!env.isExpand())
399  return identity;
400  // Load from expanded access pattern if filled, identity otherwise.
401  Value values = env.getExpandValues();
402  Value filled = env.getExpandFilled();
403  Value index = genIndex(env, t);
404  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
405  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
406  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
407 }
408 
409 static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
410  Value sparseOut, ValueRange ivs, Value v) {
411  scf::IfOp condInsert =
412  builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
413  // True branch.
414  builder.setInsertionPointToStart(condInsert.thenBlock());
415  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
416  builder.create<scf::YieldOp>(loc, res);
417  // False branch.
418  builder.setInsertionPointToStart(condInsert.elseBlock());
419  builder.create<scf::YieldOp>(loc, sparseOut);
420  // Value assignment.
421  builder.setInsertionPointAfter(condInsert);
422  return condInsert.getResult(0);
423 }
424 
425 /// Generates insertion code to implement dynamic tensor store.
426 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
427  Value rhs) {
428  linalg::GenericOp op = env.op();
429  Location loc = op.getLoc();
430  // Direct insertion in lexicographic coordinate order.
431  if (!env.isExpand()) {
432  const LoopId numLoops = op.getRank(t);
433  // Retrieves the first `numLoops` induction variables.
434  SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
435  env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
436  Value chain = env.getInsertionChain();
437  if (env.isValidLexInsert()) {
438  // Generates runtime check for a valid lex during reduction,
439  // to avoid inserting the identity value for empty reductions.
440  // if (validLexInsert) then
441  // insert(rhs) into chain
442  // return updated chain
443  // else
444  // return unmodified chain
445  Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
446  chain, ivs, rhs);
447  env.updateInsertionChain(out);
448  } else {
449  Value sparseOut;
450  if (!hasAnySparseType(env.op().getInputs().getTypes())) {
451  // This is an all-dense -> sparse kernel, test rhs != 0 before
452  // insertion.
453  Value nz = genIsNonzero(builder, loc, rhs);
454  sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
455  } else {
456  sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
457  }
458  // Generates regular insertion chain.
459  env.updateInsertionChain(sparseOut);
460  }
461  return;
462  }
463  // Generates insertion code along expanded access pattern.
464  // if (!expFilled[i]) then
465  // expFilled[i] = true
466  // expAdded[inserts++] = i
467  // endif
468  // values[i] = rhs
469  Value values = env.getExpandValues();
470  Value filled = env.getExpandFilled();
471  Value added = env.getExpandAdded();
472  Value count = env.getExpandCount();
473  Value index = genIndex(env, t);
474  Value fval = constantI1(builder, loc, false);
475  Value tval = constantI1(builder, loc, true);
476  // If statement.
477  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
478  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
479  isFilled, fval);
480  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
481  /*else=*/true);
482  // True branch.
483  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
484  builder.create<memref::StoreOp>(loc, tval, filled, index);
485  builder.create<memref::StoreOp>(loc, index, added, count);
486  Value one = constantIndex(builder, loc, 1);
487  Value add = builder.create<arith::AddIOp>(loc, count, one);
488  builder.create<scf::YieldOp>(loc, add);
489  // False branch.
490  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
491  builder.create<scf::YieldOp>(loc, count);
492  builder.setInsertionPointAfter(ifOp);
493  // Value assignment.
494  env.updateExpandCount(ifOp.getResult(0));
495  builder.create<memref::StoreOp>(loc, rhs, values, index);
496 }
497 
498 /// Generates a load on a dense or sparse tensor.
499 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
500  // Test if the load was hoisted to a higher loop nest.
501  Value val = env.exp(exp).val;
502  if (val)
503  return val;
504  // Get tensor operand.
505  linalg::GenericOp op = env.op();
506  Location loc = op.getLoc();
507  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
508  // Fold binary-valued tensor into explicit value.
509  const auto stt = getSparseTensorType(t->get());
510  if (auto explVal = stt.getExplicitVal())
511  return genValFromAttr(builder, loc, explVal);
512  // Load during insertion.
513  if (env.isSparseOutput(t)) {
514  if (env.isCustomReduc())
515  return genInsertionLoadReduce(env, builder, t);
516  return genInsertionLoad(env, builder, t);
517  }
518 
519  // Actual load.
520  SmallVector<Value> args;
521  Value ptr = genSubscript(env, builder, t, args);
522  if (llvm::isa<TensorType>(ptr.getType())) {
523  assert(env.options().sparseEmitStrategy ==
524  SparseEmitStrategy::kSparseIterator);
525  return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
526  }
527  return builder.create<memref::LoadOp>(loc, ptr, args);
528 }
529 
530 /// Generates a store on a dense or sparse tensor.
531 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
532  Value rhs) {
533  // Only unary and binary operations are allowed to return an uninitialized
534  // rhs to indicate a missing output, or otherwise a custom reduction that
535  // received no value to accumulate.
536  if (!rhs) {
537  assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
538  env.exp(exp).kind == TensorExp::Kind::kBinary ||
539  env.exp(exp).kind == TensorExp::Kind::kReduce);
540  return;
541  }
542  // Test if this is a scalarized reduction.
543  if (env.isReduc()) {
544  env.updateReduc(rhs);
545  return;
546  }
547  // Regular store.
548  linalg::GenericOp op = env.op();
549  Location loc = op.getLoc();
550  OpOperand *t = op.getDpsInitOperand(0);
551  if (!env.isSparseOutput(t)) {
552  SmallVector<Value> args;
553  Value ptr = genSubscript(env, builder, t, args);
554  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
555  return;
556  }
557  // Store during sparse insertion.
558  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
559  genInsertionStore(env, builder, t, rhs);
560  return;
561  }
562  // Select operation insertion.
563  Value chain = env.getInsertionChain();
564  scf::IfOp ifOp =
565  builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
566  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
567  // Existing value was preserved to be used here.
568  assert(env.exp(exp).val);
569  Value v0 = env.exp(exp).val;
570  genInsertionStore(env, builder, t, v0);
571  env.merger().clearExprValue(exp);
572  // Yield modified insertion chain along true branch.
573  Value mchain = env.getInsertionChain();
574  builder.create<scf::YieldOp>(op.getLoc(), mchain);
575  // Yield original insertion chain along false branch.
576  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
577  builder.create<scf::YieldOp>(loc, chain);
578  // Done with if statement.
579  env.updateInsertionChain(ifOp->getResult(0));
580  builder.setInsertionPointAfter(ifOp);
581 }
582 
583 /// Generates an invariant value.
584 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
585  return env.exp(exp).val;
586 }
587 
588 /// Semi-ring branches are simply inlined by the sparsifier. Prior
589 /// analysis has verified that all computations are "local" to the inlined
590 /// branch or otherwise invariantly defined outside the loop nest, with the
591 /// exception of index computations, which need to be relinked to actual
592 /// inlined cloned code.
593 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
594  Value e) {
595  if (auto arg = dyn_cast<BlockArgument>(e)) {
596  // Direct arguments of the original linalg op must be converted
597  // into dense tensor loads. Note that we should not encounter
598  // anything else. This needs to be verified by semi-ring ops.
599  linalg::GenericOp op = env.op();
600  if (arg.getOwner()->getParentOp() == op) {
601  const TensorId tid = env.makeTensorId(arg.getArgNumber());
602  OpOperand *t = &op->getOpOperand(tid);
603  assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
604  SmallVector<Value> args;
605  Value ptr = genSubscript(env, rewriter, t, args);
606  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
607  }
608  } else if (Operation *def = e.getDefiningOp()) {
609  // Handle index computation.
610  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
611  return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
612  // When still defined in new body, recurse into operands.
613  if (def->getBlock() == block) {
614  rewriter.setInsertionPoint(def);
615  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
616  rewriter.modifyOpInPlace(def, [&]() {
617  def->setOperand(
618  i, relinkBranch(env, rewriter, block, def->getOperand(i)));
619  });
620  }
621  }
622  }
623  return e;
624 }
625 
626 /// Recursively generates tensor expression.
627 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
628  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
629  return Value();
630 
631  linalg::GenericOp op = env.op();
632  Location loc = op.getLoc();
633  const TensorExp &exp = env.exp(e);
634  const auto kind = exp.kind;
635  if (kind == TensorExp::Kind::kTensor)
636  return genTensorLoad(env, rewriter, e);
637  if (kind == TensorExp::Kind::kInvariant)
638  return genInvariantValue(env, e);
639  if (kind == TensorExp::Kind::kLoopVar)
640  return env.getLoopVar(exp.loop);
641 
642  if (kind == TensorExp::Kind::kReduce)
643  env.startCustomReduc(e); // enter custom
644 
645  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
646  // based on the type of the other operand.
647  Value v0, v1;
648  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
649  env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
650  v1 = genExp(env, rewriter, exp.children.e1);
651  v0 = constantZero(rewriter, loc, v1.getType());
652  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
653  env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
654  v0 = genExp(env, rewriter, exp.children.e0);
655  v1 = constantZero(rewriter, loc, v0.getType());
656  } else {
657  v0 = genExp(env, rewriter, exp.children.e0);
658  v1 = genExp(env, rewriter, exp.children.e1);
659  }
660 
661  Value ee;
662  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
663  // custom reduce did not receive a value
664  } else {
665  ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
666  if (ee &&
667  (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
668  kind == TensorExp::Kind::kBinaryBranch ||
669  kind == TensorExp::Kind::kReduce ||
670  kind == TensorExp::Kind::kSelect)) {
671  OpBuilder::InsertionGuard guard(rewriter);
672  ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
673  }
674  }
675 
676  if (kind == TensorExp::Kind::kReduce)
677  env.endCustomReduc(); // exit custom
678 
679  if (kind == TensorExp::Kind::kSelect)
680  env.merger().setExprValue(e, v0); // Preserve value for later use.
681 
682  return ee;
683 }
684 
685 /// Hoists loop invariant tensor loads for which indices have been exhausted.
686 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
687  LoopId curr, bool isStart) {
688  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
689  return;
690  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
691  // Inspect tensor indices.
692  linalg::GenericOp op = env.op();
693  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
694  const auto map = op.getMatchingIndexingMap(&t);
695  const auto stt = getSparseTensorType(t.get());
696  const Level lvlRank = stt.getLvlRank();
697  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
698  bool isCurrentLoop = curr == 0; // for scalar tensors
699  for (Level l = 0; l < lvlRank; l++) {
700  const AffineExpr a = map.getResult(l);
701  if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
702  return; // still in play
703  }
704  // All exhausted at current level.
705  if (!isCurrentLoop)
706  return;
707  // Generate code for a scalarized reduction or invariant. Note that
708  // because custom reduction lhs may occur several times in the IR,
709  // we have a built-in safety for only initializing and wrapping-up
710  // the scalarized reduction once.
711  OpOperand *lhs = op.getDpsInitOperand(0);
712  if (lhs == &t) {
713  // Start or end a scalarized reduction.
714  if (isStart) {
715  if (env.isCustomReduc()) {
716  if (!env.isReduc())
717  env.startReduc(exp, env.getCustomRedId());
718  } else {
719  env.startReduc(exp, genTensorLoad(env, builder, exp));
720  }
721  if (env.hasSparseOutput())
722  env.startValidLexInsert(
723  constantI1(builder, env.op().getLoc(), false));
724  } else {
725  if (!env.isCustomReduc() || env.isReduc())
726  genTensorStore(env, builder, exp, env.endReduc());
727  if (env.hasSparseOutput())
728  env.endValidLexInsert();
729  }
730  } else {
731  // Start or end loop invariant hoisting of a tensor load.
732  if (isStart) {
733  env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
734  } else {
735  env.merger().clearExprValue(exp);
736  }
737  }
738  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
739  env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
740  env.exp(exp).kind != TensorExp::Kind::kSynZero) {
741  // Traverse into the binary operations. Note that we only hoist
742  // tensor loads, since subsequent MLIR/LLVM passes know how to
743  // deal with all other kinds of derived loop invariants.
744  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
745  env.startCustomReduc(exp); // enter custom
746  const ExprId e0 = env.exp(exp).children.e0;
747  const ExprId e1 = env.exp(exp).children.e1;
748  genInvariants(env, builder, e0, curr, isStart);
749  genInvariants(env, builder, e1, curr, isStart);
750  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
751  env.endCustomReduc(); // exit custom
752  }
753 }
754 
755 /// Generates an expanded access pattern in innermost dimension.
756 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
757  bool isStart) {
758  linalg::GenericOp op = env.op();
759  OpOperand *lhs = op.getDpsInitOperand(0);
760  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
761  return; // not needed at current level
762  assert(!env.isReduc());
763  // Generate start or end of an expanded access pattern. Note that because
764  // an expansion does not rely on the ongoing contents of the sparse storage
765  // scheme, we can use the original tensor as incoming SSA value (which
766  // simplifies codegen a bit). If expansion on the actual contents is ever
767  // needed, we will need to use the SSA value in the insertion chain instead.
768  Value tensor = lhs->get();
769  Location loc = op.getLoc();
770  if (isStart) {
771  auto dynShape = {ShapedType::kDynamic};
772  Type etp = cast<ShapedType>(tensor.getType()).getElementType();
773  Type t1 = MemRefType::get(dynShape, etp);
774  Type t2 = MemRefType::get(dynShape, builder.getI1Type());
775  Type t3 = MemRefType::get(dynShape, builder.getIndexType());
776  Type t4 = builder.getIndexType();
777  auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
778  assert(r.getNumResults() == 4);
779  env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
780  r.getResult(3));
781  } else {
782  SmallVector<Value> indices;
783  for (LoopId i = 0; i < curr; i++)
784  indices.push_back(env.emitter().getLoopIV(i));
785  Value values = env.getExpandValues();
786  Value filled = env.getExpandFilled();
787  Value added = env.getExpandAdded();
788  Value count = env.getExpandCount();
789  Value chain = env.getInsertionChain();
790  Value compress = builder.create<CompressOp>(loc, values, filled, added,
791  count, chain, indices);
792  env.updateInsertionChain(compress);
793  env.endExpand();
794  }
795 }
796 
797 /// Returns parallelization strategy. Any implicit loop in the Linalg
798 /// operation that is marked "parallel" is a candidate. Whether it is actually
799 /// converted to a parallel operation depends on the requested strategy.
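/// For example, under kDenseOuterLoop only the outermost loop is a candidate,
/// and only when it does not iterate over a sparse level, whereas
/// kAnyStorageAnyLoop admits every parallel loop regardless of storage.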
800 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
801  // Reject parallelization of sparse output.
802  if (env.hasSparseOutput())
803  return false;
804  // Parallel loops on tensor expansion can cause data races.
805  if (env.isExpand())
806  return false;
807  // Inspect strategy.
808  switch (env.options().parallelizationStrategy) {
809  case SparseParallelizationStrategy::kNone:
810  return false;
811  case SparseParallelizationStrategy::kDenseOuterLoop:
812  return isOuter && !isSparse;
813  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
814  return isOuter;
815  case SparseParallelizationStrategy::kDenseAnyLoop:
816  return !isSparse;
817  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
818  return true;
819  }
820  llvm_unreachable("unexpected parallelization strategy");
821 }
822 
823 /// Whether or not the current loop being generated should be parallelized
824 /// (if possible) according to the configuration.
825 static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
826  ArrayRef<TensorLevel> tidLvls) {
827  linalg::GenericOp op = env.op();
828  auto iteratorTypes = op.getIteratorTypesArray();
829  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
830  // Queries the LT based on the tensor and loop id, as requested by
831  // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
832  // should be consistent with the LT indexed by <TensorId, Level>.
833  const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
834  return lt.hasSparseSemantic();
835  });
836  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
837 }
838 
839 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
840 /// is a for-loop when there is at most one sparse level in the list, and a
841 /// while-loop otherwise.
842 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
843  ArrayRef<TensorLevel> tidLvls,
844  unsigned numCases, bool tryParallel,
845  bool needsUniv) {
846  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
847  // Construct while-loop with a parameter for each index.
848  return env.emitter().enterCoIterationOverTensorsAtLvls(
849  builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
850  needsUniv);
851  });
852  assert(loop);
853  return loop;
854 }
855 
856 /// Generates a for-loop or a while-loop, depending on whether it implements
857 /// singleton iteration or co-iteration over the given conjunction.
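/// For example, x(i) = a(i) * b(i) with a sparse and b dense is a singleton
/// iteration over a and compiles to a for-loop, whereas x(i) = a(i) + b(i)
/// with both inputs sparse co-iterates a and b in a while-loop.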
858 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
859  unsigned numCases, bool needsUniv,
860  ArrayRef<TensorLevel> tidLvls) {
861  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
862  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
863  needsUniv);
864 }
865 
866 /// Generates the induction structure for a while-loop.
867 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
868  bool needsUniv) {
869  Location loc = env.op().getLoc();
870  // Finalize each else branch of all if statements.
871  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
872  while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
873  builder.getInsertionBlock()->getParentOp())) {
874  // Break on IfOp for slicing filtering.
875  if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
876  StringAttr::get(ifOp->getContext(), "slice"))
877  break;
878 
879  unsigned y = 0;
880  SmallVector<Value> yields;
881  if (env.isReduc()) {
882  yields.push_back(env.getReduc());
883  env.updateReduc(ifOp.getResult(y++));
884  if (env.isValidLexInsert()) {
885  yields.push_back(env.getValidLexInsert());
886  env.updateValidLexInsert(ifOp.getResult(y++));
887  }
888  }
889  if (env.isExpand()) {
890  yields.push_back(env.getExpandCount());
891  env.updateExpandCount(ifOp->getResult(y++));
892  }
893  if (env.getInsertionChain()) {
894  yields.push_back(env.getInsertionChain());
895  env.updateInsertionChain(ifOp->getResult(y++));
896  }
897  assert(y == yields.size());
898  builder.create<scf::YieldOp>(loc, yields);
899  builder.setInsertionPointAfter(ifOp);
900  }
901  }
902  // No need to set the insertion point here as LoopEmitter keeps track of the
903  // basic block where scf::Yield should be inserted.
904 }
905 
906 /// Generates a case region in the coiterate operation.
907 static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
908  unsigned caseIdx, LatPointId allCase,
909  LatPointId curCase,
910  MutableArrayRef<Value> reduc) {
911  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
912  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
913  const BitVector &curCaseBits = env.merger().lat(curCase).simple;
914 
915  /// Computes the subset of iterators that are valid in the current case being
916  /// generated.
917  I64BitSet caseBit(0);
918  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
919  if (curCaseBits.test(set))
920  caseBit.set(idx);
921 
922  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
923  caseIdx, reduc);
924 }
925 
926 /// Generates a single if-statement within a while-loop.
927 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
928  LatPointId p) {
929  Location loc = env.op().getLoc();
930  SmallVector<Type> types;
931  Value cond;
932  env.merger().foreachTensorLoopId(
933  p, /*simple=*/true,
934  [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
935  bool isIdxRed) {
936  if (isIdxRed) {
937  // Since there is no 1:1 mapping from loop to level (multiple loops
938  // are required to resolve one level with non-trivial index
939  // expression), we need to reconstruct the tensor level types if this
940  // loop requires an index-reduction condition.
941  assert(lvl.has_value() && isUndefLT(lt));
942  auto stt = getSparseTensorType(env.op().getInputs()[tid]);
943  lt = stt.getLvlType(*lvl);
944  }
945  assert(curr == env.merger().loop(b));
946  Value clause;
947  if (lt.hasSparseSemantic()) {
948  assert(lvl.has_value());
949  const Value crd = env.emitter().getCoord(tid, *lvl);
950  const Value lvar = env.getLoopVar(curr);
951  clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
952  crd, lvar);
953  } else {
954  assert(lt.hasDenseSemantic() || isUndefLT(lt));
955  clause = constantI1(builder, loc, true);
956  }
957  cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
958  });
959  if (env.isReduc()) {
960  types.push_back(env.getReduc().getType());
961  if (env.isValidLexInsert())
962  types.push_back(env.getValidLexInsert().getType());
963  }
964  if (env.isExpand())
965  types.push_back(builder.getIndexType());
966  if (env.getInsertionChain())
967  types.push_back(env.getInsertionChain().getType());
968  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
969  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
970  return ifOp;
971 }
972 
973 /// Generates end of true branch of if-statement within a while-loop.
974 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
975  Value redInput, Value cntInput, Value insInput,
976  Value validIns) {
977  SmallVector<Value> operands;
978  if (env.isReduc()) {
979  operands.push_back(env.getReduc());
980  env.updateReduc(redInput);
981  if (env.isValidLexInsert()) {
982  // Any overlapping indices during a reduction creates a valid lex insert.
983  operands.push_back(constantI1(builder, env.op().getLoc(), true));
984  env.updateValidLexInsert(validIns);
985  }
986  }
987  if (env.isExpand()) {
988  operands.push_back(env.getExpandCount());
989  env.updateExpandCount(cntInput);
990  }
991  if (env.getInsertionChain()) {
992  operands.push_back(env.getInsertionChain());
993  env.updateInsertionChain(insInput);
994  }
995  if (!operands.empty())
996  builder.create<scf::YieldOp>(env.op().getLoc(), operands);
997  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
998 }
999 
1000 //===----------------------------------------------------------------------===//
1001 // Sparsifier synthesis methods (loop sequence).
1002 //===----------------------------------------------------------------------===//
1003 
1004 static bool getAllTidLvlsInLatPoints(
1005  CodegenEnv &env, LatPointId li, LoopId curr,
1006  llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1007  const BitVector &simple = env.lat(li).simple;
1008  const TensorId outTid = env.merger().getOutTensorID();
1009  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1010 
1011  unsigned numloopCond = 0;
1012  bool hasNonUnique = false;
1013  env.merger().foreachTensorLoopId(
1014  li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1015  LevelType lt, bool isIdxReduc) {
1016  if (simple[b]) {
1017  if (isIdxReduc) {
1018  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1019  numloopCond++;
1020  return;
1021  }
1022  if (isUndefLT(lt)) {
1023  // An undefined lt in the lattices means we probably intend to
1024  // generate a dense loop according to the synthetic tensor (for
1025  // invariants and the sparse output tensor).
1026  if (env.merger().getSynTensorID() == tid) {
1027  // Coiterating with an invariant
1028  // e.g., out = prod(in[i][j] op invariant);
1029  // or a broadcast
1030  // e.g., out[i][j] = in[i] (j is undef for input)
1031  //
1032  // The level of the synthetic tensor is the current loop depth;
1033  // the rank of the synthetic tensor equals the number of loops.
1034  assert(curr == env.getCurrentDepth());
1035  lvl = curr;
1036  } else if (!lvl) {
1037  // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1038  return;
1039  }
1040  }
1041  hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1042  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1043  numloopCond++;
1044  } else if (lt.hasDenseSemantic() || isIdxReduc) {
1045  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1046  } else {
1047  assert(isUndefLT(lt));
1048  linalg::GenericOp op = env.op();
1049  if (tid >= op.getNumDpsInputs())
1050  // We only handle affine expression on input tensors (for now).
1051  return;
1052  OpOperand *operand = &op->getOpOperand(tid);
1053  const auto stt = getSparseTensorType(operand->get());
1054  // Non-annotated dense tensors require no special handling.
1055  if (!stt.hasEncoding())
1056  return;
1057 
1058  ArrayRef<AffineExpr> affines =
1059  op.getMatchingIndexingMap(operand).getResults();
1060  const Level lvlRank = stt.getLvlRank();
1061  assert(affines.size() == static_cast<size_t>(lvlRank));
1062  for (Level l = 0; l < lvlRank; l++) {
1063  AffineExpr exp = affines[l];
1064  // Skip simple affine expressions and non-dense levels (which
1065  // have their own filter loop).
1066  LevelType lt = stt.getLvlType(l);
1067  if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1068  continue;
1069 
1070  // Constant affine expressions are handled in genLoop.
1071  if (!isa<AffineConstantExpr>(exp)) {
1072  bool isCurrentLoop = false;
1073  assert(curr == env.getCurrentDepth());
1074  if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1075  isCurrentLoop) {
1076  // If the compound affine expression is invariant and we are right at
1077  // the level, we need to generate the address according to the
1078  // affine expression. This is also the best place to do it, so as
1079  // to avoid putting it inside inner loops.
1080  callback(env.makeTensorLevel(tid, l), exp);
1081  }
1082  }
1083  }
1084  }
1085  });
1086 
1087  if (isDenseLT(env.lt(outTid, curr))) {
1088  auto stt = getSparseTensorType(env.op().getOutputs().front());
1089  // Note that we generate dense indices of the output tensor unconditionally,
1090  // since they may not appear in the lattice, but may be needed for
1091  // linearized env.
1092  // TODO: we should avoid introducing corner cases for all-dense sparse
1093  // tensors.
1094  if (stt.hasEncoding() && stt.isAllDense())
1095  callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1096  }
1097 
1098  if (numloopCond == 0) {
1099  // Corner case where the loop bound is defined by an *unused* operand; in
1100  // this case, we just generate a dense "fake" loop by iterating over the
1101  // synthetic tensor.
1102  callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1103  numloopCond++;
1104  }
1105  // If we just need one loop condition and the condition is not imposed on a
1106  // non-unique level, the loop can be generated by a for loop.
1107  // Or, if we are generating sparse-iterator-based loops, we always generate
1108  // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1109  return numloopCond == 1 &&
1110  (!hasNonUnique || env.options().sparseEmitStrategy ==
1111  SparseEmitStrategy::kSparseIterator);
1112 }
1113 
1114 /// Starts a loop sequence at given level. Returns true if
1115 /// the universal loop index must be maintained at this level.
1116 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1117  LoopId curr, LatSetId lts) {
1118  assert(!env.getLoopVar(curr));
1119  // Emit invariants at this loop sequence level.
1120  genInvariants(env, builder, exp, curr, /*isStart=*/true);
1121  // Emit access pattern expansion for sparse tensor output.
1122  genExpand(env, builder, curr, /*isStart=*/true);
1123  // Emit further initialization at this loop sequence level.
1124  const LatPointId l0 = env.set(lts)[0];
1125 
1126  SmallVector<TensorLevel> tidLvls;
1127  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1128  // TODO: remove this! The same tensor level might be added multiple
1129  // times due to the special handling for all-dense "sparse" output tensor
1130  // (see L1038).
1131  if (llvm::find(tidLvls, tl) != tidLvls.end())
1132  return;
1133  tidLvls.emplace_back(tl);
1134  });
1135 
1136  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1137 
1138  // Maintain the universal index only if it is actually
1139  // consumed by a subsequent lattice point.
1140  for (const LatPointId li : env.set(lts).drop_front())
1141  if (!env.merger().hasAnySparse(env.lat(li).simple))
1142  return true;
1143 
1144  return false;
1145 }
1146 
1147 // Generates dense affine address for encoding.
1148 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1149  OpBuilder &builder, TensorId tid,
1150  Level startLvl) {
1151  // TODO: Handle affine expression on output tensor.
1152  linalg::GenericOp op = env.op();
1153  assert(tid < op.getNumDpsInputs());
1154  OpOperand *input = op.getDpsInputOperands()[tid];
1155  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1156  const auto enc = getSparseTensorEncoding(input->get().getType());
1157  if (enc) {
1158  const Location loc = op.getLoc();
1159  const TensorId tid = env.makeTensorId(input->getOperandNumber());
1160  const Level lvlRank = enc.getLvlRank();
1161  assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1162  for (Level l = startLvl; l < lvlRank; l++) {
1163  AffineExpr lvlExpr = lvlExprs[l];
1164  if (enc.getLvlType(l).hasDenseSemantic() &&
1165  isa<AffineConstantExpr>(lvlExpr))
1166  env.emitter().locateLvlAtAffineAddress(
1167  builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1168  else
1169  return; // break on first non-dense non-constant level
1170  }
1171  }
1172 }
1173 
1174 // We can generate address for constant affine expression before any loops
1175 // starting from the first level as they do not depend on anything.
1176 // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1177 // levels can be determined before loops.
1178 static void genInitConstantDenseAddress(CodegenEnv &env,
1179  RewriterBase &rewriter) {
1180  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1181  genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1182 }
1183 
1184 /// Returns true if the lattice bit can be iterated by a for loop.
1185 static bool translateBitsToTidLvlPairs(
1186  CodegenEnv &env, LatPointId li, LoopId curr,
1187  SmallVectorImpl<TensorLevel> &tidLvls,
1188  SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1189  return getAllTidLvlsInLatPoints(env, li, curr,
1190  [&](TensorLevel tl, AffineExpr exp) {
1191  if (exp)
1192  affineTidLvls.emplace_back(tl, exp);
1193  else
1194  tidLvls.emplace_back(tl);
1195  });
1196 }
1197 
1198 /// Starts a single loop in current sequence.
1199 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1200  OpBuilder &builder, LoopId curr,
1201  LatPointId li, unsigned numCases,
1202  bool needsUniv) {
1203  // TODO: numCases is only used when generating iterator-based loops. Clean
1204  // this up after the migration is complete.
1205  // The set of tensors + lvls to generate loops on
1206  SmallVector<TensorLevel> tidLvls;
1207 
1208  // The set of dense tensors with non-trivial affine expressions that just
1209  // became invariant and whose addresses are generated at the current level.
1210  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1211  bool isSingleCond =
1212  translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1213 
1214  // Emit the for/while-loop control.
1215  Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1216  Location loc = env.op().getLoc();
1217  for (auto [tidLvl, exp] : affineTidLvls) {
1218  env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1219  }
1220 
1221  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1222  // affine}Tids/Lvls. The addresses of the upcoming levels which depend
1223  // on constant affine expressions may now be determined.
1224  auto allTidLvls =
1225  llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1226  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1227  if (tid != env.merger().getOutTensorID() &&
1228  tid != env.merger().getSynTensorID())
1229  genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1230  }
1231 
1232  return std::make_pair(loop, isSingleCond);
1233 }
1234 
1235 /// Ends a single loop in current sequence. Returns new values for needsUniv.
1236 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1237  LatPointId li, bool needsUniv, bool isSingleCond) {
1238  // Either a for-loop or a while-loop that iterates over a slice.
1239  if (isSingleCond) {
1240  // Any iteration creates a valid lex insert.
1241  if (env.isReduc() && env.isValidLexInsert())
1242  env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1243  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1244  // End a while-loop.
1245  finalizeWhileOp(env, rewriter, needsUniv);
1246  } else {
1247  needsUniv = false;
1248  }
1249  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1250  env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1251  return std::nullopt;
1252  });
1253  return needsUniv;
1254 }
1255 
1256 /// Ends a loop sequence at given level.
1257 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1258  unsigned at) {
1259  assert(!env.getLoopVar(at));
1260  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1261  // Unmark bookkeeping of invariants and loop index.
1262  genInvariants(env, builder, exp, at, /*isStart=*/false);
1263  // Finalize access pattern expansion for sparse tensor output.
1264  genExpand(env, builder, at, /*isStart=*/false);
1265 }
1266 
1267 /// Recursively generates code while computing iteration lattices in order
1268 /// to manage the complexity of implementing co-iteration over unions
1269 /// and intersections of sparse iteration spaces.
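/// For example, for x(i) = a(i) + b(i) with both inputs sparse, the lattice
/// for loop i contains the points {a,b}, {a}, and {b}: the generated loop
/// sequence starts with a while-loop co-iterating a and b (with if-statements
/// inside), followed by loops that handle the remaining nonzeros of a and b.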
1270 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1271  LoopId curr) {
1272  assert(curr == env.getCurrentDepth());
1273 
1274  // At each leaf, assign remaining tensor (sub)expression to output tensor.
1275  if (curr == env.getLoopNum()) {
1276  Value rhs = genExp(env, rewriter, exp);
1277  genTensorStore(env, rewriter, exp, rhs);
1278  return;
1279  }
1280 
1281  // Construct iteration lattices for current loop index.
1282  const LatSetId lts =
1283  env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1284 
1285  // Start a loop sequence.
1286  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1287 
1288  // When using sparse-iterator-based loops, we only need one loop, as
1289  // opposed to a loop sequence, to cover all the iteration spaces.
1290  const unsigned lsize = env.set(lts).size();
1291  if (env.generatingSparseIterator()) {
1292  // Get the largest lattice point and start a loop.
1293  const LatPointId li = env.set(lts)[0];
1294  auto [loop, isSingleCond] =
1295  startLoop(env, rewriter, curr, li, lsize, needsUniv);
1296  assert(isSingleCond == llvm::isa<IterateOp>(loop));
1297  // We cannot change this to `for (const LatPointId li : env.set(lts))`
1298  // because the loop body causes data-movement which invalidates
1299  // the iterator.
1300  for (unsigned j = 0; j < lsize; j++) {
1301  const LatPointId lj = env.set(lts)[j];
1302  const ExprId ej = env.lat(lj).exp;
1303  // Recurse into body of each branch.
1304  if (!isSingleCond) {
1305  env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1306  genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1307  genStmt(env, rewriter, ej, curr + 1);
1308  // TODO: handle yield values.
1309  assert(reduc.empty() && "Not Implemented");
1310  rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1311  return std::nullopt;
1312  });
1313  // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1314  } else {
1315  genStmt(env, rewriter, ej, curr + 1);
1316  }
1317  }
1318  // End a loop.
1319  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1320  } else {
1321  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1322  for (unsigned i = 0; i < lsize; i++) {
1323  const LatPointId li = env.set(lts)[i];
1324  // Start a loop.
1325  auto [loop, isSingleCond] =
1326  startLoop(env, rewriter, curr, li, lsize, needsUniv);
1327 
1328  // Visit all lattices points with Li >= Lj to generate the
1329  // loop-body, possibly with if statements for coiteration.
1330  Value redInput = env.getReduc();
1331  Value cntInput = env.getExpandCount();
1332  Value insInput = env.getInsertionChain();
1333  Value validIns = env.getValidLexInsert();
1334  // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1335  // because the loop body causes data-movement which invalidates the
1336  // iterator.
1337  for (unsigned j = 0; j < lsize; j++) {
1338  const LatPointId lj = env.set(lts)[j];
1339  const ExprId ej = env.lat(lj).exp;
1340  if (li == lj || env.merger().latGT(li, lj)) {
1341  // Recurse into body of each branch.
1342  if (!isSingleCond) {
1343  scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1344  genStmt(env, rewriter, ej, curr + 1);
1345  endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1346  } else {
1347  genStmt(env, rewriter, ej, curr + 1);
1348  }
1349  }
1350  }
1351 
1352  // End a loop.
1353  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1354  }
1355  }
1356 
1357  // End a loop sequence.
1358  endLoopSeq(env, rewriter, exp, curr);
1359  assert(curr == env.getCurrentDepth());
1360 }
1361 
1362 /// Converts the result computed by the sparse kernel into the required form.
1363 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1364  linalg::GenericOp op = env.op();
1365  OpOperand *lhs = op.getDpsInitOperand(0);
1366  Value tensor = lhs->get();
1367  Type resType = tensor.getType();
1368  if (getSparseTensorEncoding(resType)) {
1369  // The sparse tensor rematerializes from the original sparse tensor's
1370  // underlying sparse storage format. For an insertion chain, the
1371  // tensor materializes from the chain with 'hasInserts' enabled.
1372  bool hasInserts = false;
1373  if (Value chain = env.getInsertionChain()) {
1374  hasInserts = true;
1375  tensor = chain;
1376  }
1377  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1378  } else {
1379  // To rematerialize an non-annotated tensor, simply load it
1380  // from the bufferized value.
1381  Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1382  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1383  }
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // Sparsifier rewriting methods.
1388 //===----------------------------------------------------------------------===//
1389 
1390 namespace {
1391 
1392 /// Sparse rewriting rule for generic Linalg operation.
1393 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1394 public:
1395  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1396  : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1397 
1398  LogicalResult matchAndRewrite(linalg::GenericOp op,
1399  PatternRewriter &rewriter) const override {
1400  // Only accept single output operations with pure tensor semantics.
1401  if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1402  return failure();
1403 
1404  // Only accept trivial affine indices.
1405  if (hasNonTrivialAffineOnSparseOut(op))
1406  return failure();
1407 
1408  // Only accept scheduled loops.
1409  if (!op->hasAttr("sorted")) {
1410  return rewriter.notifyMatchFailure(
1411  op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1412  "before sparsification.");
1413  }
1414 
1415  // Must have been demapped as well if the generic op is sorted.
1416  assert(!hasAnyNonIdentityOperandsOrResults(op));
1417 
1418  // Sets up a code generation environment.
1419  const unsigned numTensors = op->getNumOperands();
1420  const unsigned numLoops = op.getNumLoops();
1421  bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1422  // If we have an indexing map like (d0) -> (0, d0), there might be more
1423  // levels than loops because of the constant index, which means we cannot
1424  // use numLoops as the upper bound for tensor ranks (see note below).
1425  // TODO: Constant indices are currently not supported on sparse tensors,
1426  // but are allowed in non-annotated dense tensors. Supporting them would
1427  // also be required for rank-reducing sparse tensor slices.
1428  Level maxLvlRank = 0;
1429  for (auto operand : op.getOperands()) {
1430  if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1431  maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1432  }
1433  }
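 // Worked example of the constant-index case noted above: with indexing
 // map (d0) -> (0, d0), a kernel with numLoops == 1 still addresses a
 // two-level tensor, so maxLvlRank == 2 (not numLoops) is the correct
 // upper bound on tensor level-ranks.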
1434 
1435  // Detects sparse annotations and translates the per-level sparsity
1436  // information for all tensors to loop indices in the kernel.
1437  CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1438  if (!findSparseAnnotations(env, needIdxRed))
1439  return failure();
1440 
1441  // Only standard reduction operations (add, sub, or, xor) that can be
1442  // sparsified by merely reducing the stored values are admissible. More
1443  // elaborate reduction operations (such as mul, and, min, max) would need
1444  // to know whether implicit zeros occur as well. They can still be
1445  // implemented with a custom reduction operation, which is accepted here.
1446  if (op.getNumReductionLoops() > 0) {
1447  Operation *yield = op.getRegion().front().getTerminator();
1448  assert(isa<linalg::YieldOp>(yield));
1449  Operation *redop = yield->getOperand(0).getDefiningOp();
1450  if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1451  !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1452  !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1453  !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1454  !isa<ReduceOp>(redop)) {
1455  return failure();
1456  }
1457  }
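 // Why the restriction above is sound: implicit zeros are the identity of
 // the admissible operations, so reducing only the stored values is exact:
 //   sum(all entries) = sum(stored entries) + 0 * (#implicit zeros)
 //                    = sum(stored entries).
 // For mul, and, min, max the implicit zeros do contribute, e.g.
 //   prod({2, 3, implicit 0}) = 0 while prod(stored {2, 3}) = 6,
 // which is why such reductions must be expressed with a custom ReduceOp.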
1458 
1459  // Constructs the tensor expression tree from `op`; returns failure if the
1460  // tree cannot be built or the tensor expression is inadmissible.
1461  if (failed(env.initTensorExp()))
1462  return failure();
1463 
1464  // Recursively generates code if admissible.
1465  env.startEmit(options.sparseEmitStrategy);
1466  genBuffers(env, rewriter);
1467  // TODO: Constant affine expressions should be handled differently when
1468  // using slice-based codegen; it does not matter now because we already
1469  // reject constant expressions at an earlier stage.
1470  genInitConstantDenseAddress(env, rewriter);
1471  genStmt(env, rewriter, env.getExprId(), 0);
1472  genResult(env, rewriter);
1473  return success();
1474  }
1475 
1476 private:
1477  /// Options to control sparse code generation.
1478  SparsificationOptions options;
1479 };
1480 
1481 } // namespace
1482 
1483 /// Populates the given patterns list with rewriting rules required for
1484 /// the sparsification of linear algebra operations.
1485 void mlir::populateSparsificationPatterns(
1486  RewritePatternSet &patterns, const SparsificationOptions &options) {
1487  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1488 }
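
A minimal usage sketch of the entry point above, assuming the greedy rewrite
driver from mlir/Transforms/GreedyPatternRewriteDriver.h (applyPatternsGreedily
in current MLIR; older releases spell it applyPatternsAndFoldGreedily). The
wrapper runSparsification is hypothetical and not part of this file:

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Hypothetical helper: registers the sparsification rules and applies
  // them to `rootOp` until a fixed point is reached.
  static void runSparsification(Operation *rootOp,
                                const SparsificationOptions &options) {
    RewritePatternSet patterns(rootOp->getContext());
    populateSparsificationPatterns(patterns, options);
    if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
      rootOp->emitError("sparsification patterns did not converge");
  }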