MLIR  22.0.0git
Sparsification.cpp
1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
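// Editorial note (illustrative sketch, not part of the original file): given a
// linalg.generic kernel such as x(i) = a(i) * b(i) where `a` uses a compressed
// (sparse) level, the sparsifier emits imperative code that only visits the
// stored entries, conceptually:
//
//   for (p = a_pos[0]; p < a_pos[1]; p++)   // iterate stored nonzeros of `a`
//     x[a_crd[p]] = a_values[p] * b[a_crd[p]];
//
// The rest of this file analyzes the kernel (level types, iteration lattices)
// and synthesizes such loops for arbitrary mixes of dense and sparse operands.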
13 #include "Utils/CodegenEnv.h"
14 #include "Utils/CodegenUtils.h"
15 #include "Utils/LoopEmitter.h"
16 
31 
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace mlir::sparse_tensor;
36 
37 //===----------------------------------------------------------------------===//
38 // Sparsifier analysis methods.
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true iff the affine expression is invariant. Sets the
42 /// parameter `isCurrentLoop` when the expression just became invariant.
43 static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
44  switch (a.getKind()) {
45  case AffineExprKind::DimId: {
46  const LoopId i = cast<AffineDimExpr>(a).getPosition();
47  if (i + 1 == curr) {
48  isCurrentLoop = true;
49  return true; // becomes invariant at current loop
50  }
51  return i < curr; // invariant when already generated
52  }
53  case AffineExprKind::Add:
54  case AffineExprKind::Mul: {
55  auto binOp = cast<AffineBinaryOpExpr>(a);
56  return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
57  isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
58  }
59  default: {
60  assert(isa<AffineConstantExpr>(a));
61  return true;
62  }
63  }
64 }
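// Example (editorial, derived from the cases above): with curr == 2, the
// expression d0 + d1 is invariant and sets `isCurrentLoop` (since d1 + 1 ==
// curr), whereas d0 + d2 is not yet invariant because loop d2 has not been
// generated.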
65 
66 /// Helper method to inspect affine expressions. Rejects cases where the
67 /// same index is used more than once. Also rejects compound affine
68 /// expressions in sparse dimensions.
69 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
70  LevelType lt, bool setLvlFormat = true) {
71  switch (a.getKind()) {
72  case AffineExprKind::DimId: {
73  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
74  if (!isUndefLT(merger.getLvlType(tid, idx)))
75  return false; // used more than once
76  if (setLvlFormat)
77  merger.setLevelAndType(tid, idx, lvl, lt);
78  return true;
79  }
80  case AffineExprKind::Add:
81  case AffineExprKind::Mul:
82  case AffineExprKind::Constant: {
83  assert(lt.hasDenseSemantic());
84  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
85  // We do not set the dim level format for an affine expression like
86  // d0 + d1 on either loop index d0 or d1. We continue the recursion
87  // merely to check whether the current affine expression is admissible.
88  return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
89  findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
90  }
91  // Falls through when it is a constant affine expression.
92  return true;
93  }
94  default:
95  return false;
96  }
97 }
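// Example (editorial): for a map (d0, d1) -> (d0, d0), the second result
// revisits d0, whose level type was already set by the first result, so the
// helper returns false ("used more than once"), mirroring the A[i][i] case.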
98 
99 /// Helper method to inspect affine expressions for index variable reduction
100 /// based codegen. It finds the dependent index set for all tensor levels in the
101 /// current expression we are generating.
102 ///
103 /// For example, when handling A[i+j][j+k], we build the two way mapping in
104 /// merger between (tensor, level) pairs and their dependent index variable set:
105 /// A_0 <=> [i, j] and A_1 <=> [j, k]
106 ///
107 /// It rejects cases (returns false)
108 /// 1st, when the same index is used more than once, e.g., A[i+j][i]
109 /// 2nd, when multiplication is used in the non-trivial index expression.
110 /// 3rd, when a constant operand is used in the non-trivial index expression.
111 ///
112 /// TODO: constant should be easy to handle.
113 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
114  AffineExpr a, LevelType lt, bool isSubExp = false,
115  int64_t coefficient = 1) {
116  switch (a.getKind()) {
117  case AffineExprKind::DimId: {
118  // Only allow positive coefficients on AffineDimExpr.
119  if (coefficient <= 0)
120  return false;
121 
122  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
123  if (!isUndefLT(merger.getLvlType(tensor, idx)))
124  return false; // used more than once, e.g., A[i][i]
125 
126  // TODO: Generalize the following two cases. A[i] (with a trivial index
127  // expression) can be treated as a special affine index expression. We do
128  // not necessarily need to differentiate them.
129  if (!isSubExp) {
130  assert(coefficient == 1);
131  merger.setLevelAndType(tensor, idx, lvl, lt);
132  }
133 
134  if (isSubExp) {
135  // The current loop appears in more than one affine expression on the
136  // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
137  // is used twice.
138  if (merger.hasDependentLvl(idx, tensor)) {
139  // TODO: This can be supported by coiterating slices if the loop idx
140  // appears in the affine index of a different tensor, or by taking a
141  // slice on multiple dimensions when it is on the same tensor.
142  // E.g.,
143  // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
144  // d0_1 = getNextSliceOffset t0 along lvl0
145  // d0_2 = getNextSliceOffset t1 along lvl0
146  // if d0_1 == d0_2 then d0 = d0_1 = d0_2
147  // else increase min(d0_1, d0_2).
148  return false;
149  }
150  merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
151  }
152  return true;
153  }
154  case AffineExprKind::Constant:
155  case AffineExprKind::Mul: {
156  // TODO: Support index expressions like `2 * d0`; for now we only support
157  // more complicated cases like `2 * d0 + d1`.
158  if (!isSubExp)
159  return false;
160 
161  // TODO: Support Constant AffineExp for slice-based codegen
162  if (isa<AffineConstantExpr>(a))
163  llvm_unreachable("Not yet implemented");
164 
165  auto binOp = cast<AffineBinaryOpExpr>(a);
166  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
167  if (isa<AffineConstantExpr>(rhs))
168  std::swap(lhs, rhs);
169  // Must be in form of `constant * d`.
170  assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
171  int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
172  return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
173  }
174  case AffineExprKind::Add: {
175  auto binOp = cast<AffineBinaryOpExpr>(a);
176  return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
177  findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
178  }
179  default:
180  return false;
181  }
182 }
183 
184 /// Gets the total number of compound affine expressions in the
185 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
186 ///
187 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
188 ///
189 /// Returns 1 (because the first level is compressed and its corresponding
190 /// indexing-expression is `d0 + d1`)
191 static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
192  Value tensor) {
193  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
194  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
195  // However, we don't need to handle `StorageSpecifierType`, so we
196  // can use `SparseTensorType` once we guard against non-tensors.
197  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
198  if (!rtp)
199  return 0;
200  const SparseTensorType stt(rtp);
201 
202  const Level lvlRank = stt.getLvlRank();
203  const auto exprs = map.getResults();
204  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
205  "AffineMap does not have dimension-rank many results");
206  unsigned num = 0;
207  for (Level l = 0; l < lvlRank; l++) {
208  if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
209  num++;
210  }
211  return num;
212 }
213 
214 /// Gets the total number of sparse levels with compound affine
215 /// expressions, summed over all operands of the `GenericOp`.
216 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
217  unsigned num = 0;
218  for (OpOperand &t : op->getOpOperands())
219  num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
220  t.get());
221  return num;
222 }
223 
224 // Returns true iff output has nontrivial affine indices.
225 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
226  OpOperand *out = op.getDpsInitOperand(0);
227  if (getSparseTensorType(out->get()).isAllDense())
228  return false;
229  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
230  out->get());
231 }
232 
233 /// Helper method to inspect sparse encodings in the tensor types.
234 /// Fills the per-dimension sparsity information for all tensors.
235 /// Returns true if the sparse annotations and affine subscript
236 /// expressions of all tensors are admissible. Returns false if
237 /// no annotations are found or inadmissible constructs occur.
238  /// We currently support two different ways to handle non-trivial index
239  /// expressions on sparse tensors, and they accept different affine expressions.
240  /// The dependent index reduction-based approach currently only supports
241  /// affine addition index expressions.
242 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
243  bool annotated = false;
244  for (OpOperand &t : env.op()->getOpOperands()) {
245  const TensorId tid = env.makeTensorId(t.getOperandNumber());
246  const auto map = env.op().getMatchingIndexingMap(&t);
247  const auto enc = getSparseTensorEncoding(t.get().getType());
248  if (enc)
249  annotated = true;
250  const Level lvlRank = map.getNumResults();
251  assert(!enc || lvlRank == enc.getLvlRank());
252  assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
253  // We only need to do index reduction if there is at least one
254  // non-trivial index expression on sparse levels. If all non-trivial
255  // index expressions are on dense levels, we can efficiently rely on
256  // random access to locate the element.
257  bool needIdxReduc =
258  enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
259  // If the current tensor being inspected requires an affine index, it
260  // needs to be sliced.
261  for (Level l = 0; l < lvlRank; l++) {
262  const AffineExpr a = map.getResult(l);
263  const LevelType lt = enc.getLvlType(l);
264  if (idxReducBased && needIdxReduc) {
265  if (!findDepIdxSet(env.merger(), tid, l, a, lt))
266  return false; // inadmissible affine expression
267  } else {
268  if (!findAffine(env.merger(), tid, l, a, lt))
269  return false; // inadmissible affine expression
270  }
271  }
272  }
273  return annotated;
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // Sparsifier synthesis methods (statements and expressions).
278 //===----------------------------------------------------------------------===//
279 
280 /// Local bufferization of all dense and sparse data structures.
281 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
282  linalg::GenericOp op = env.op();
283  Location loc = op.getLoc();
284  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
285 
286  SmallVector<Range, 4> loopRange =
287  llvm::cast<linalg::LinalgOp>(op.getOperation())
288  .createLoopRanges(builder, loc);
289 
290  env.emitter().initializeLoopEmit(
291  builder, loc,
292  /// Generates buffer for the output tensor.
293  /// Note that all sparse kernels assume that when all elements are written
294  /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
295  /// to all zeroes and only nonzero values are computed and written out.
296  /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
297  /// for the updates and no assumption on the original contents of the
298  /// output buffer is necessary.
299  [&op](OpBuilder &builder, Location loc, Value memref,
300  Value tensor) -> Value {
301  // Must not be a sparse tensor.
302  assert(!getSparseTensorEncoding(tensor.getType()));
303  // Two output tensor references should point to the same object.
304  OpOperand *lhs = op.getDpsInitOperand(0);
305  assert(lhs->get() == tensor);
306  // An output tensor can simply materialize from the buffer of the tensor
307  // that appears in the outs() clause. For updates, this has the
308  // advantage that only the nonzero values are involved in the
309  // computation, keeping the operation O(nnz). In all other cases, we are
310  // forced to zero out the buffer to enforce the assumption above, which
311  // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
312  // O(nnz) for matrices).
313  // TODO: use better analysis to avoid zeroing out the buffer?
314  bool isInit = op.isInitTensor(lhs);
315  Value init = memref;
316  if (!isInit) {
317  Value zero = constantZero(builder, loc,
318  getElementTypeOrSelf(tensor.getType()));
319  linalg::FillOp::create(builder, loc, ValueRange{zero},
320  ValueRange{init});
321  }
322  return init;
323  },
324  [&loopRange](OpBuilder &b, Location loc, Level l) {
325  assert(l < loopRange.size());
326  return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
327  });
328 }
329 
330 /// Generates index for load/store on sparse tensor.
331 static Value genIndex(CodegenEnv &env, OpOperand *t) {
332  const auto map = env.op().getMatchingIndexingMap(t);
333  const auto stt = getSparseTensorType(t->get());
334  const Level lvlRank = stt.getLvlRank();
335  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
336  const AffineExpr a = map.getResult(lvlRank - 1);
337  assert(a.getKind() == AffineExprKind::DimId);
338  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
339  return env.getLoopVar(idx);
340 }
341 
342 /// Generates subscript for load/store on a dense or sparse tensor.
343 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
344  SmallVectorImpl<Value> &args) {
345  const Location loc = env.op().getLoc();
346  const TensorId tid = env.makeTensorId(t->getOperandNumber());
347  const auto map = env.op().getMatchingIndexingMap(t);
348  const auto stt = getSparseTensorType(t->get());
349  if (stt.hasEncoding()) {
350  // For sparse tensors we only push the last-level's position onto `args`.
351  const auto pos = env.emitter().getValPosits(tid);
352  assert(!pos.empty());
353  args.append(pos);
354  // Simply returns the tensor to extract value using iterators.
355  if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
356  return t->get();
357  } else {
358  // For dense tensors we push all level's coordinates onto `args`.
359  const Level lvlRank = stt.getLvlRank();
360  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
361  for (Level l = 0; l < lvlRank; l++) {
362  const auto lvlExpr = map.getResult(l);
363  const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
364  args.push_back(lvlCrd);
365  }
366  }
367  return env.emitter().getValBuffer()[tid];
368 }
369 
370 /// Generates insertion code to implement dynamic tensor load.
371 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
372  OpOperand *t) {
373  linalg::GenericOp op = env.op();
374  Location loc = op.getLoc();
375  // Direct lexicographic coordinate order, tensor loads as zero.
376  if (!env.isExpand()) {
377  Type tp = getElementTypeOrSelf(t->get().getType());
378  return constantZero(builder, loc, tp);
379  }
380  // Load from expanded access pattern.
381  Value index = genIndex(env, t);
382  return memref::LoadOp::create(builder, loc, env.getExpandValues(), index);
383 }
384 
385 /// Generates insertion code to implement dynamic tensor load for reduction.
386 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
387  OpOperand *t) {
388  linalg::GenericOp op = env.op();
389  Location loc = op.getLoc();
390  Value identity = env.getCustomRedId();
391  // Direct lexicographic coordinate order, tensor loads as identity.
392  if (!env.isExpand())
393  return identity;
394  // Load from expanded access pattern if filled, identity otherwise.
395  Value values = env.getExpandValues();
396  Value filled = env.getExpandFilled();
397  Value index = genIndex(env, t);
398  Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
399  Value valAtIndex = memref::LoadOp::create(builder, loc, values, index);
400  return arith::SelectOp::create(builder, loc, isFilled, valAtIndex, identity);
401 }
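// In effect (editorial note): the value loaded above is
// select(filled[index], values[index], identity), so positions that were never
// written in the expanded access pattern contribute the reduction identity.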
402 
403 static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
404  Value sparseOut, ValueRange ivs, Value v) {
405  scf::IfOp condInsert =
406  scf::IfOp::create(builder, loc, sparseOut.getType(), cond, true);
407  // True branch.
408  builder.setInsertionPointToStart(condInsert.thenBlock());
409  Value res = tensor::InsertOp::create(builder, loc, v, sparseOut, ivs);
410  scf::YieldOp::create(builder, loc, res);
411  // False branch.
412  builder.setInsertionPointToStart(condInsert.elseBlock());
413  scf::YieldOp::create(builder, loc, sparseOut);
414  // Value assignment.
415  builder.setInsertionPointAfter(condInsert);
416  return condInsert.getResult(0);
417 }
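// Schematically (editorial sketch of the IR built above, assuming a 1-d
// output and ignoring types):
//
//   %r = scf.if %cond -> (tensor<...>) {
//     %t = tensor.insert %v into %sparseOut[%ivs]
//     scf.yield %t
//   } else {
//     scf.yield %sparseOut
//   }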
418 
419 /// Generates insertion code to implement dynamic tensor store.
420 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
421  Value rhs) {
422  linalg::GenericOp op = env.op();
423  Location loc = op.getLoc();
424  // Direct insertion in lexicographic coordinate order.
425  if (!env.isExpand()) {
426  const LoopId numLoops = op.getRank(t);
427  // Retrieves the first `numLoops` induction variables.
428  SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
429  env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
430  Value chain = env.getInsertionChain();
431  if (env.isValidLexInsert()) {
432  // Generates runtime check for a valid lex during reduction,
433  // to avoid inserting the identity value for empty reductions.
434  // if (validLexInsert) then
435  // insert(rhs) into chain
436  // return updated chain
437  // else
438  // return unmodified chain
439  Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
440  chain, ivs, rhs);
441  env.updateInsertionChain(out);
442  } else {
443  Value sparseOut;
444  if (!hasAnySparseType(env.op().getInputs().getTypes())) {
445  // This is an all-dense -> sparse kernel, test rhs != 0 before
446  // insertion.
447  Value nz = genIsNonzero(builder, loc, rhs);
448  sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
449  } else {
450  sparseOut = tensor::InsertOp::create(builder, loc, rhs, chain, ivs);
451  }
452  // Generates regular insertion chain.
453  env.updateInsertionChain(sparseOut);
454  }
455  return;
456  }
457  // Generates insertion code along expanded access pattern.
458  // if (!expFilled[i]) then
459  // expFilled[i] = true
460  // expAdded[inserts++] = i
461  // endif
462  // values[i] = rhs
463  Value values = env.getExpandValues();
464  Value filled = env.getExpandFilled();
465  Value added = env.getExpandAdded();
466  Value count = env.getExpandCount();
467  Value index = genIndex(env, t);
468  Value fval = constantI1(builder, loc, false);
469  Value tval = constantI1(builder, loc, true);
470  // If statement.
471  Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
472  Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
473  isFilled, fval);
474  scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIndexType(), cond,
475  /*else=*/true);
476  // True branch.
477  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
478  memref::StoreOp::create(builder, loc, tval, filled, index);
479  memref::StoreOp::create(builder, loc, index, added, count);
480  Value one = constantIndex(builder, loc, 1);
481  Value add = arith::AddIOp::create(builder, loc, count, one);
482  scf::YieldOp::create(builder, loc, add);
483  // False branch.
484  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
485  scf::YieldOp::create(builder, loc, count);
486  builder.setInsertionPointAfter(ifOp);
487  // Value assignment.
488  env.updateExpandCount(ifOp.getResult(0));
489  memref::StoreOp::create(builder, loc, rhs, values, index);
490 }
491 
492 /// Generates a load on a dense or sparse tensor.
493 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
494  // Test if the load was hoisted to a higher loop nest.
495  Value val = env.exp(exp).val;
496  if (val)
497  return val;
498  // Get tensor operand.
499  linalg::GenericOp op = env.op();
500  Location loc = op.getLoc();
501  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
502  // Fold binary-valued tensor into explicit value.
503  const auto stt = getSparseTensorType(t->get());
504  if (auto explVal = stt.getExplicitVal())
505  return genValFromAttr(builder, loc, explVal);
506  // Load during insertion.
507  if (env.isSparseOutput(t)) {
508  if (env.isCustomReduc())
509  return genInsertionLoadReduce(env, builder, t);
510  return genInsertionLoad(env, builder, t);
511  }
512 
513  // Actual load.
514  SmallVector<Value> args;
515  Value ptr = genSubscript(env, builder, t, args);
516  if (llvm::isa<TensorType>(ptr.getType())) {
517  assert(env.options().sparseEmitStrategy ==
518  SparseEmitStrategy::kSparseIterator);
519  return ExtractValOp::create(builder, loc, ptr,
520  llvm::getSingleElement(args));
521  }
522  return memref::LoadOp::create(builder, loc, ptr, args);
523 }
524 
525 /// Generates a store on a dense or sparse tensor.
526 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
527  Value rhs) {
528  // Only unary and binary ops are allowed to return an uninitialized rhs
529  // to indicate a missing output, or otherwise a custom reduction that
530  // received no value to accumulate.
531  if (!rhs) {
532  assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
533  env.exp(exp).kind == TensorExp::Kind::kBinary ||
534  env.exp(exp).kind == TensorExp::Kind::kReduce);
535  return;
536  }
537  // Test if this is a scalarized reduction.
538  if (env.isReduc()) {
539  env.updateReduc(rhs);
540  return;
541  }
542  // Regular store.
543  linalg::GenericOp op = env.op();
544  Location loc = op.getLoc();
545  OpOperand *t = op.getDpsInitOperand(0);
546  if (!env.isSparseOutput(t)) {
547  SmallVector<Value> args;
548  Value ptr = genSubscript(env, builder, t, args);
549  memref::StoreOp::create(builder, loc, rhs, ptr, args);
550  return;
551  }
552  // Store during sparse insertion.
553  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
554  genInsertionStore(env, builder, t, rhs);
555  return;
556  }
557  // Select operation insertion.
558  Value chain = env.getInsertionChain();
559  scf::IfOp ifOp =
560  scf::IfOp::create(builder, loc, chain.getType(), rhs, /*else=*/true);
561  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
562  // Existing value was preserved to be used here.
563  assert(env.exp(exp).val);
564  Value v0 = env.exp(exp).val;
565  genInsertionStore(env, builder, t, v0);
566  env.merger().clearExprValue(exp);
567  // Yield modified insertion chain along true branch.
568  Value mchain = env.getInsertionChain();
569  scf::YieldOp::create(builder, op.getLoc(), mchain);
570  // Yield original insertion chain along false branch.
571  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
572  scf::YieldOp::create(builder, loc, chain);
573  // Done with if statement.
574  env.updateInsertionChain(ifOp->getResult(0));
575  builder.setInsertionPointAfter(ifOp);
576 }
577 
578 /// Generates an invariant value.
579 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
580  return env.exp(exp).val;
581 }
582 
583 /// Semi-ring branches are simply inlined by the sparsifier. Prior
584 /// analysis has verified that all computations are "local" to the inlined
585 /// branch or otherwise invariantly defined outside the loop nest, with the
586 /// exception of index computations, which need to be relinked to actual
587 /// inlined cloned code.
588 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
589  Value e) {
590  if (auto arg = dyn_cast<BlockArgument>(e)) {
591  // Direct arguments of the original linalg op must be converted
592  // into dense tensor loads. Note that we should not encounter
593  // anything else. This needs to be verified by semi-ring ops.
594  linalg::GenericOp op = env.op();
595  if (arg.getOwner()->getParentOp() == op) {
596  const TensorId tid = env.makeTensorId(arg.getArgNumber());
597  OpOperand *t = &op->getOpOperand(tid);
598  assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
599  SmallVector<Value> args;
600  Value ptr = genSubscript(env, rewriter, t, args);
601  return memref::LoadOp::create(rewriter, op.getLoc(), ptr, args);
602  }
603  } else if (Operation *def = e.getDefiningOp()) {
604  // Handle index computation.
605  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
606  return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
607  // When still defined in new body, recurse into operands.
608  if (def->getBlock() == block) {
609  rewriter.setInsertionPoint(def);
610  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
611  rewriter.modifyOpInPlace(def, [&]() {
612  def->setOperand(
613  i, relinkBranch(env, rewriter, block, def->getOperand(i)));
614  });
615  }
616  }
617  }
618  return e;
619 }
620 
621 /// Recursively generates tensor expression.
622 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
624  return Value();
625 
626  linalg::GenericOp op = env.op();
627  Location loc = op.getLoc();
628  const TensorExp &exp = env.exp(e);
629  const auto kind = exp.kind;
630  if (kind == TensorExp::Kind::kTensor)
631  return genTensorLoad(env, rewriter, e);
632  if (kind == TensorExp::Kind::kInvariant)
633  return genInvariantValue(env, e);
634  if (kind == TensorExp::Kind::kLoopVar)
635  return env.getLoopVar(exp.loop);
636 
637  if (kind == TensorExp::Kind::kReduce)
638  env.startCustomReduc(e); // enter custom
639 
640  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
641  // based on the type of the other operand.
642  Value v0, v1;
645  v1 = genExp(env, rewriter, exp.children.e1);
646  v0 = constantZero(rewriter, loc, v1.getType());
649  v0 = genExp(env, rewriter, exp.children.e0);
650  v1 = constantZero(rewriter, loc, v0.getType());
651  } else {
652  v0 = genExp(env, rewriter, exp.children.e0);
653  v1 = genExp(env, rewriter, exp.children.e1);
654  }
655 
656  Value ee;
657  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
658  // custom reduce did not receive a value
659  } else {
660  ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
661  if (ee &&
666  OpBuilder::InsertionGuard guard(rewriter);
667  ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
668  }
669  }
670 
671  if (kind == TensorExp::Kind::kReduce)
672  env.endCustomReduc(); // exit custom
673 
674  if (kind == TensorExp::Kind::kSelect)
675  env.merger().setExprValue(e, v0); // Preserve value for later use.
676 
677  return ee;
678 }
679 
680 /// Hoists loop invariant tensor loads for which indices have been exhausted.
681 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
682  LoopId curr, bool isStart) {
683  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
684  return;
685  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
686  // Inspect tensor indices.
687  linalg::GenericOp op = env.op();
688  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
689  const auto map = op.getMatchingIndexingMap(&t);
690  const auto stt = getSparseTensorType(t.get());
691  const Level lvlRank = stt.getLvlRank();
692  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
693  bool isCurrentLoop = curr == 0; // for scalar tensors
694  for (Level l = 0; l < lvlRank; l++) {
695  const AffineExpr a = map.getResult(l);
696  if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
697  return; // still in play
698  }
699  // All exhausted at current level.
700  if (!isCurrentLoop)
701  return;
702  // Generate code for a scalarized reduction or invariant. Note that
703  // because custom reduction lhs may occur several times in the IR,
704  // we have a built-in safety for only initializing and wrapping-up
705  // the scalarized reduction once.
706  OpOperand *lhs = op.getDpsInitOperand(0);
707  if (lhs == &t) {
708  // Start or end a scalarized reduction.
709  if (isStart) {
710  if (env.isCustomReduc()) {
711  if (!env.isReduc())
712  env.startReduc(exp, env.getCustomRedId());
713  } else {
714  env.startReduc(exp, genTensorLoad(env, builder, exp));
715  }
716  if (env.hasSparseOutput())
717  env.startValidLexInsert(
718  constantI1(builder, env.op().getLoc(), false));
719  } else {
720  if (!env.isCustomReduc() || env.isReduc())
721  genTensorStore(env, builder, exp, env.endReduc());
722  if (env.hasSparseOutput())
723  env.endValidLexInsert();
724  }
725  } else {
726  // Start or end loop invariant hoisting of a tensor load.
727  if (isStart) {
728  env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
729  } else {
730  env.merger().clearExprValue(exp);
731  }
732  }
733  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
734  env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
735  env.exp(exp).kind != TensorExp::Kind::kSynZero) {
736  // Traverse into the binary operations. Note that we only hoist
737  // tensor loads, since subsequent MLIR/LLVM passes know how to
738  // deal with all other kinds of derived loop invariants.
739  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
740  env.startCustomReduc(exp); // enter custom
741  const ExprId e0 = env.exp(exp).children.e0;
742  const ExprId e1 = env.exp(exp).children.e1;
743  genInvariants(env, builder, e0, curr, isStart);
744  genInvariants(env, builder, e1, curr, isStart);
745  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
746  env.endCustomReduc(); // exit custom
747  }
748 }
749 
750 /// Generates an expanded access pattern in innermost dimension.
751 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
752  bool isStart) {
753  linalg::GenericOp op = env.op();
754  OpOperand *lhs = op.getDpsInitOperand(0);
755  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
756  return; // not needed at current level
757  assert(!env.isReduc());
758  // Generate start or end of an expanded access pattern. Note that because
759  // an expansion does not rely on the ongoing contents of the sparse storage
760  // scheme, we can use the original tensor as incoming SSA value (which
761  // simplifies codegen a bit). If expansion on the actual contents is ever
762  // needed, we will need to use the SSA value in the insertion chain instead.
763  Value tensor = lhs->get();
764  Location loc = op.getLoc();
765  if (isStart) {
766  auto dynShape = {ShapedType::kDynamic};
767  Type etp = cast<ShapedType>(tensor.getType()).getElementType();
768  Type t1 = MemRefType::get(dynShape, etp);
769  Type t2 = MemRefType::get(dynShape, builder.getI1Type());
770  Type t3 = MemRefType::get(dynShape, builder.getIndexType());
771  Type t4 = builder.getIndexType();
772  auto r =
773  ExpandOp::create(builder, loc, TypeRange({t1, t2, t3, t4}), tensor);
774  assert(r.getNumResults() == 4);
775  env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
776  r.getResult(3));
777  } else {
778  SmallVector<Value> indices;
779  for (LoopId i = 0; i < curr; i++)
780  indices.push_back(env.emitter().getLoopIV(i));
781  Value values = env.getExpandValues();
782  Value filled = env.getExpandFilled();
783  Value added = env.getExpandAdded();
784  Value count = env.getExpandCount();
785  Value chain = env.getInsertionChain();
786  Value compress = CompressOp::create(builder, loc, values, filled, added,
787  count, chain, indices);
788  env.updateInsertionChain(compress);
789  env.endExpand();
790  }
791 }
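// In outline (editorial note): on `isStart` a single sparse_tensor.expand op
// yields scratch buffers (values, filled, added) plus a running count for the
// innermost dimension; on the way out a single sparse_tensor.compress op folds
// whatever was added back into the insertion chain at the enclosing loop
// indices.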
792 
793 /// Returns whether the loop should be parallelized. Any implicit loop in the
794 /// Linalg operation that is marked "parallel" is a candidate. Whether it is
795 /// actually converted to a parallel operation depends on the requested strategy.
796 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
797  // Reject parallelization of sparse output.
798  if (env.hasSparseOutput())
799  return false;
800  // Parallel loops on tensor expansion can cause data races.
801  if (env.isExpand())
802  return false;
803  // Inspect strategy.
804  switch (env.options().parallelizationStrategy) {
805  case SparseParallelizationStrategy::kNone:
806  return false;
807  case SparseParallelizationStrategy::kDenseOuterLoop:
808  return isOuter && !isSparse;
809  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
810  return isOuter;
811  case SparseParallelizationStrategy::kDenseAnyLoop:
812  return !isSparse;
813  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
814  return true;
815  }
816  llvm_unreachable("unexpected parallelization strategy");
817 }
818 
819 /// Whether or not the current loop being generated should be parallelized
820 /// (if possible) according to the configuration.
821 static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
822  ArrayRef<TensorLevel> tidLvls) {
823  linalg::GenericOp op = env.op();
824  auto iteratorTypes = op.getIteratorTypesArray();
825  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
826  // Queries the LT based on the tensor and loop id, as requested by
827  // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
828  // should be consistent with the LT indexed by <TensorId, Level>.
829  const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
830  return lt.hasSparseSemantic();
831  });
832  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
833 }
834 
835 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
836 /// can either be a for loop or while loop depending on whether there is at most
837 /// one sparse level in the list.
838 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
839  ArrayRef<TensorLevel> tidLvls,
840  unsigned numCases, bool tryParallel,
841  bool needsUniv) {
842  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
843  // Construct while-loop with a parameter for each index.
844  return env.emitter().enterCoIterationOverTensorsAtLvls(
845  builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
846  needsUniv);
847  });
848  assert(loop);
849  return loop;
850 }
851 
852 /// Generates a for-loop or a while-loop, depending on whether it implements
853 /// singleton iteration or co-iteration over the given conjunction.
854 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
855  unsigned numCases, bool needsUniv,
856  ArrayRef<TensorLevel> tidLvls) {
857  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
858  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
859  needsUniv);
860 }
861 
862 /// Generates the induction structure for a while-loop.
863 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
864  bool needsUniv) {
865  Location loc = env.op().getLoc();
866  // Finalize each else branch of all if statements.
867  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
868  while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
869  builder.getInsertionBlock()->getParentOp())) {
870  // Break on IfOp for slicing filtering.
871  if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
872  StringAttr::get(ifOp->getContext(), "slice"))
873  break;
874 
875  unsigned y = 0;
876  SmallVector<Value> yields;
877  if (env.isReduc()) {
878  yields.push_back(env.getReduc());
879  env.updateReduc(ifOp.getResult(y++));
880  if (env.isValidLexInsert()) {
881  yields.push_back(env.getValidLexInsert());
882  env.updateValidLexInsert(ifOp.getResult(y++));
883  }
884  }
885  if (env.isExpand()) {
886  yields.push_back(env.getExpandCount());
887  env.updateExpandCount(ifOp->getResult(y++));
888  }
889  if (env.getInsertionChain()) {
890  yields.push_back(env.getInsertionChain());
891  env.updateInsertionChain(ifOp->getResult(y++));
892  }
893  assert(y == yields.size());
894  scf::YieldOp::create(builder, loc, yields);
895  builder.setInsertionPointAfter(ifOp);
896  }
897  }
898  // No need to set the insertion point here as LoopEmitter keeps track of the
899  // basic block where scf::Yield should be inserted.
900 }
901 
902 /// Generates a case region in the coiterate operation.
903 static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
904  unsigned caseIdx, LatPointId allCase,
905  LatPointId curCase,
906  MutableArrayRef<Value> reduc) {
907  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
908  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
909  const BitVector &curCaseBits = env.merger().lat(curCase).simple;
910 
911  /// Computes the subset of iterators that are valid in the current case being
912  /// generated.
913  I64BitSet caseBit(0);
914  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
915  if (curCaseBits.test(set))
916  caseBit.set(idx);
917 
918  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
919  caseIdx, reduc);
920 }
921 
922 /// Generates a single if-statement within a while-loop.
923 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
924  LatPointId p) {
925  Location loc = env.op().getLoc();
926  SmallVector<Type> types;
927  Value cond;
928  env.merger().foreachTensorLoopId(
929  p, /*simple=*/true,
930  [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
931  bool isIdxRed) {
932  if (isIdxRed) {
933  // Since there is no 1:1 mapping from loop to level (multiple loops
934  // are required to resolve one level with non-trivial index
935  // expression), we need to reconstruct the tensor level types if this
936  // loop requires index reduction condition.
937  assert(lvl.has_value() && isUndefLT(lt));
938  auto stt = getSparseTensorType(env.op().getInputs()[tid]);
939  lt = stt.getLvlType(*lvl);
940  }
941  assert(curr == env.merger().loop(b));
942  Value clause;
943  if (lt.hasSparseSemantic()) {
944  assert(lvl.has_value());
945  const Value crd = env.emitter().getCoord(tid, *lvl);
946  const Value lvar = env.getLoopVar(curr);
947  clause = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
948  crd, lvar);
949  } else {
950  assert(lt.hasDenseSemantic() || isUndefLT(lt));
951  clause = constantI1(builder, loc, true);
952  }
953  cond =
954  cond ? arith::AndIOp::create(builder, loc, cond, clause) : clause;
955  });
956  if (env.isReduc()) {
957  types.push_back(env.getReduc().getType());
958  if (env.isValidLexInsert())
959  types.push_back(env.getValidLexInsert().getType());
960  }
961  if (env.isExpand())
962  types.push_back(builder.getIndexType());
963  if (env.getInsertionChain())
964  types.push_back(env.getInsertionChain().getType());
965  scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true);
966  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
967  return ifOp;
968 }
969 
970 /// Generates end of true branch of if-statement within a while-loop.
971 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
972  Value redInput, Value cntInput, Value insInput,
973  Value validIns) {
974  SmallVector<Value> operands;
975  if (env.isReduc()) {
976  operands.push_back(env.getReduc());
977  env.updateReduc(redInput);
978  if (env.isValidLexInsert()) {
979  // Any overlapping indices during a reduction create a valid lex insert.
980  operands.push_back(constantI1(builder, env.op().getLoc(), true));
981  env.updateValidLexInsert(validIns);
982  }
983  }
984  if (env.isExpand()) {
985  operands.push_back(env.getExpandCount());
986  env.updateExpandCount(cntInput);
987  }
988  if (env.getInsertionChain()) {
989  operands.push_back(env.getInsertionChain());
990  env.updateInsertionChain(insInput);
991  }
992  if (!operands.empty())
993  scf::YieldOp::create(builder, env.op().getLoc(), operands);
994  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
995 }
996 
997 //===----------------------------------------------------------------------===//
998 // Sparsifier synthesis methods (loop sequence).
999 //===----------------------------------------------------------------------===//
1000 
1001 static bool getAllTidLvlsInLatPoints(
1002  CodegenEnv &env, LatPointId li, LoopId curr,
1003  llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1004  const BitVector &simple = env.lat(li).simple;
1005  const TensorId outTid = env.merger().getOutTensorID();
1006  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1007 
1008  unsigned numloopCond = 0;
1009  bool hasNonUnique = false;
1010  env.merger().foreachTensorLoopId(
1011  li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1012  LevelType lt, bool isIdxReduc) {
1013  if (simple[b]) {
1014  if (isIdxReduc) {
1015  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1016  numloopCond++;
1017  return;
1018  }
1019  if (isUndefLT(lt)) {
1020  // An undefined lt in the lattices means we probably intend to
1021  // generate a dense loop according to the synthetic tensor (for
1022  // invariants and the sparse output tensor).
1023  if (env.merger().getSynTensorID() == tid) {
1024  // Coiterating with an invariant
1025  // e.g., out = prod(in[i][j] op invariant);
1026  // or a broadcast
1027  // e.g., out[i][j] = in[i] (j is undef for input)
1028  //
1029  // The level of the synthetic tensor is the current loop depth;
1030  // the rank of the synthetic tensor equals the number of loops.
1031  assert(curr == env.getCurrentDepth());
1032  lvl = curr;
1033  } else if (!lvl) {
1034  // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1035  return;
1036  }
1037  }
1038  hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1039  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1040  numloopCond++;
1041  } else if (lt.hasDenseSemantic() || isIdxReduc) {
1042  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1043  } else {
1044  assert(isUndefLT(lt));
1045  linalg::GenericOp op = env.op();
1046  if (tid >= op.getNumDpsInputs())
1047  // We only handle affine expression on input tensors (for now).
1048  return;
1049  OpOperand *operand = &op->getOpOperand(tid);
1050  const auto stt = getSparseTensorType(operand->get());
1051  // Non-annotated dense tensors require no special handling.
1052  if (!stt.hasEncoding())
1053  return;
1054 
1055  ArrayRef<AffineExpr> affines =
1056  op.getMatchingIndexingMap(operand).getResults();
1057  const Level lvlRank = stt.getLvlRank();
1058  assert(affines.size() == static_cast<size_t>(lvlRank));
1059  for (Level l = 0; l < lvlRank; l++) {
1060  AffineExpr exp = affines[l];
1061  // Skip simple affine expressions and non-dense levels (which
1062  // have their own filter loop).
1063  LevelType lt = stt.getLvlType(l);
1064  if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1065  continue;
1066 
1067  // Constant affine expressions are handled in genLoop.
1068  if (!isa<AffineConstantExpr>(exp)) {
1069  bool isCurrentLoop = false;
1070  assert(curr == env.getCurrentDepth());
1071  if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1072  isCurrentLoop) {
1073  // If the compound affine is invariant and we are right at the
1074  // level. We need to generate the address according to the
1075  // affine expression. This is also the best place we can do it
1076  // to avoid putting it inside inner loops.
1077  callback(env.makeTensorLevel(tid, l), exp);
1078  }
1079  }
1080  }
1081  }
1082  });
1083 
1084  if (isDenseLT(env.lt(outTid, curr))) {
1085  auto stt = getSparseTensorType(env.op().getOutputs().front());
1086  // Note that we generate dense indices of the output tensor unconditionally,
1087  // since they may not appear in the lattice, but may be needed for
1088  // linearized env.
1089  // TODO: we should avoid introducing corner cases for all-dense sparse
1090  // tensors.
1091  if (stt.hasEncoding() && stt.isAllDense())
1092  callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1093  }
1094 
1095  if (numloopCond == 0) {
1096  // Corner case where the loop bound is defined by an *unused* operand; in
1097  // this case, we just generate a dense "fake" loop by iterating over the
1098  // synthetic tensor.
1099  callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1100  numloopCond++;
1101  }
1102  // If we just need one loop condition and the condition is not imposed on a
1103  // non-unique level, the loop can be generated by a for loop.
1104  // Or, if we are generating sparse-iterator-based loops, we always generate
1105  // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1106  return numloopCond == 1 &&
1107  (!hasNonUnique || env.options().sparseEmitStrategy ==
1108  SparseEmitStrategy::kSparseIterator);
1109 }
1110 
1111 /// Starts a loop sequence at given level. Returns true if
1112 /// the universal loop index must be maintained at this level.
1113 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1114  LoopId curr, LatSetId lts) {
1115  assert(!env.getLoopVar(curr));
1116  // Emit invariants at this loop sequence level.
1117  genInvariants(env, builder, exp, curr, /*isStart=*/true);
1118  // Emit access pattern expansion for sparse tensor output.
1119  genExpand(env, builder, curr, /*isStart=*/true);
1120  // Emit further initialization at this loop sequence level.
1121  const LatPointId l0 = env.set(lts)[0];
1122 
1123  SmallVector<TensorLevel> tidLvls;
1124  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1125  // TODO: remove this! The same tensor level might be added multiple
1126  // times due to the special handling for the all-dense "sparse" output
1127  // tensor (see L1038).
1128  if (llvm::is_contained(tidLvls, tl))
1129  return;
1130  tidLvls.emplace_back(tl);
1131  });
1132 
1133  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1134 
1135  // Maintain the universal index only if it is actually
1136  // consumed by a subsequent lattice point.
1137  for (const LatPointId li : env.set(lts).drop_front())
1138  if (!env.merger().hasAnySparse(env.lat(li).simple))
1139  return true;
1140 
1141  return false;
1142 }
1143 
1144 // Generates dense affine address for encoding.
1145 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1146  OpBuilder &builder, TensorId tid,
1147  Level startLvl) {
1148  // TODO: Handle affine expression on output tensor.
1149  linalg::GenericOp op = env.op();
1150  assert(tid < op.getNumDpsInputs());
1151  OpOperand *input = op.getDpsInputOperands()[tid];
1152  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1153  const auto enc = getSparseTensorEncoding(input->get().getType());
1154  if (enc) {
1155  const Location loc = op.getLoc();
1156  const TensorId tid = env.makeTensorId(input->getOperandNumber());
1157  const Level lvlRank = enc.getLvlRank();
1158  assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1159  for (Level l = startLvl; l < lvlRank; l++) {
1160  AffineExpr lvlExpr = lvlExprs[l];
1161  if (enc.getLvlType(l).hasDenseSemantic() &&
1162  isa<AffineConstantExpr>(lvlExpr))
1163  env.emitter().locateLvlAtAffineAddress(
1164  builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1165  else
1166  return; // break on first non-dense non-constant level
1167  }
1168  }
1169 }
1170 
1171 // We can generate addresses for constant affine expressions before any loops,
1172 // starting from the first level, as they do not depend on anything.
1173 // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1174 // levels can be determined before loops.
1175 static void genInitConstantDenseAddress(CodegenEnv &env,
1176  RewriterBase &rewriter) {
1177  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1178  genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1179 }
1180 
1181 /// Returns true if the lattice bit can be iterated by a for loop.
1182 static bool translateBitsToTidLvlPairs(
1183  CodegenEnv &env, LatPointId li, LoopId curr,
1184  SmallVectorImpl<TensorLevel> &tidLvls,
1185  SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1186  return getAllTidLvlsInLatPoints(env, li, curr,
1187  [&](TensorLevel tl, AffineExpr exp) {
1188  if (exp)
1189  affineTidLvls.emplace_back(tl, exp);
1190  else
1191  tidLvls.emplace_back(tl);
1192  });
1193 }
1194 
1195 /// Starts a single loop in current sequence.
1196 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1197  OpBuilder &builder, LoopId curr,
1198  LatPointId li, unsigned numCases,
1199  bool needsUniv) {
1200  // TODO: numCases is only used when generating iterator-based loops. Clean
1201  // up after the migration is complete.
1202  // The set of tensors + lvls to generate loops on
1203  SmallVector<TensorLevel> tidLvls;
1204 
1205  // The set of dense tensors with a non-trivial affine expression that just
1206  // became invariant; their addresses are generated at the current level.
1207  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1208  bool isSingleCond =
1209  translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1210 
1211  // Emit the for/while-loop control.
1212  Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1213  Location loc = env.op().getLoc();
1214  for (auto [tidLvl, exp] : affineTidLvls) {
1215  env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1216  }
1217 
1218  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1219  // affine}Tids/Lvls. The addresses of the upcoming levels which depend
1220  // on constant affine expressions may now be determined.
1221  auto allTidLvls =
1222  llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1223  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1224  if (tid != env.merger().getOutTensorID() &&
1225  tid != env.merger().getSynTensorID())
1226  genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1227  }
1228 
1229  return std::make_pair(loop, isSingleCond);
1230 }
1231 
1232 /// Ends a single loop in the current sequence. Returns the new value for needsUniv.
1233 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1234  LatPointId li, bool needsUniv, bool isSingleCond) {
1235  // Either a for-loop or a while-loop that iterates over a slice.
1236  if (isSingleCond) {
1237  // Any iteration creates a valid lex insert.
1238  if (env.isReduc() && env.isValidLexInsert())
1239  env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1240  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1241  // End a while-loop.
1242  finalizeWhileOp(env, rewriter, needsUniv);
1243  } else {
1244  needsUniv = false;
1245  }
1246  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1247  env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1248  return std::nullopt;
1249  });
1250  return needsUniv;
1251 }
1252 
1253 /// Ends a loop sequence at given level.
1254 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1255  unsigned at) {
1256  assert(!env.getLoopVar(at));
1257  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1258  // Unmark bookkeeping of invariants and loop index.
1259  genInvariants(env, builder, exp, at, /*isStart=*/false);
1260  // Finalize access pattern expansion for sparse tensor output.
1261  genExpand(env, builder, at, /*isStart=*/false);
1262 }
1263 
1264 /// Recursively generates code while computing iteration lattices in order
1265 /// to manage the complexity of implementing co-iteration over unions
1266 /// and intersections of sparse iteration spaces.
1267 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1268  LoopId curr) {
1269  assert(curr == env.getCurrentDepth());
1270 
1271  // At each leaf, assign remaining tensor (sub)expression to output tensor.
1272  if (curr == env.getLoopNum()) {
1273  Value rhs = genExp(env, rewriter, exp);
1274  genTensorStore(env, rewriter, exp, rhs);
1275  return;
1276  }
1277 
1278  // Construct iteration lattices for current loop index.
1279  const LatSetId lts =
1280  env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1281 
1282  // Start a loop sequence.
1283  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1284 
1285  // When using sparse-iterator-based loops, we only need one loop, as
1286  // opposed to a loop sequence, to cover all the iteration spaces.
1287  const unsigned lsize = env.set(lts).size();
1288  if (env.generatingSparseIterator()) {
1289  // Get the largest lattice point and start a loop.
1290  const LatPointId li = env.set(lts)[0];
1291  auto [loop, isSingleCond] =
1292  startLoop(env, rewriter, curr, li, lsize, needsUniv);
1293  assert(isSingleCond == llvm::isa<IterateOp>(loop));
1294  // We cannot change this to `for (const LatPointId li : env.set(lts))`
1295  // because the loop body causes data-movement which invalidates
1296  // the iterator.
1297  for (unsigned j = 0; j < lsize; j++) {
1298  const LatPointId lj = env.set(lts)[j];
1299  const ExprId ej = env.lat(lj).exp;
1300  // Recurse into body of each branch.
1301  if (!isSingleCond) {
1302  env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1303  genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1304  genStmt(env, rewriter, ej, curr + 1);
1305  // TODO: handle yield values.
1306  assert(reduc.empty() && "Not Implemented");
1307  sparse_tensor::YieldOp::create(rewriter, env.op().getLoc());
1308  return std::nullopt;
1309  });
1310  // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1311  } else {
1312  genStmt(env, rewriter, ej, curr + 1);
1313  }
1314  }
1315  // End a loop.
1316  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1317  } else {
1318  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1319  for (unsigned i = 0; i < lsize; i++) {
1320  const LatPointId li = env.set(lts)[i];
1321  // Start a loop.
1322  auto [loop, isSingleCond] =
1323  startLoop(env, rewriter, curr, li, lsize, needsUniv);
1324 
1325  // Visit all lattices points with Li >= Lj to generate the
1326  // loop-body, possibly with if statements for coiteration.
1327  Value redInput = env.getReduc();
1328  Value cntInput = env.getExpandCount();
1329  Value insInput = env.getInsertionChain();
1330  Value validIns = env.getValidLexInsert();
1331  // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1332  // because the loop body causes data-movement which invalidates the
1333  // iterator.
1334  for (unsigned j = 0; j < lsize; j++) {
1335  const LatPointId lj = env.set(lts)[j];
1336  const ExprId ej = env.lat(lj).exp;
1337  if (li == lj || env.merger().latGT(li, lj)) {
1338  // Recurse into body of each branch.
1339  if (!isSingleCond) {
1340  scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1341  genStmt(env, rewriter, ej, curr + 1);
1342  endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1343  } else {
1344  genStmt(env, rewriter, ej, curr + 1);
1345  }
1346  }
1347  }
1348 
1349  // End a loop.
1350  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1351  }
1352  }
1353 
1354  // End a loop sequence.
1355  endLoopSeq(env, rewriter, exp, curr);
1356  assert(curr == env.getCurrentDepth());
1357 }
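// In outline (editorial sketch of the recursion above):
//
//   genStmt(exp, curr):
//     if (curr == numLoops) { genTensorStore(genExp(exp)); return; }
//     lts = optimizeSet(buildLattices(exp, curr));
//     needsUniv = startLoopSeq(...);
//     for each lattice point li in lts:      // or one coiteration op
//       loop = startLoop(...);
//       for each lj with li >= lj:
//         genStmt(exp(lj), curr + 1);        // guarded by genIf when needed
//       needsUniv = endLoop(...);
//     endLoopSeq(...);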
1358 
1359 /// Converts the result computed by the sparse kernel into the required form.
1360 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1361  linalg::GenericOp op = env.op();
1362  OpOperand *lhs = op.getDpsInitOperand(0);
1363  Value tensor = lhs->get();
1364  Type resType = tensor.getType();
1365  if (getSparseTensorEncoding(resType)) {
1366  // The sparse tensor rematerializes from the original sparse tensor's
1367  // underlying sparse storage format. For an insertion chain, the
1368  // tensor materializes from the chain with 'hasInserts' enabled.
1369  bool hasInserts = false;
1370  if (Value chain = env.getInsertionChain()) {
1371  hasInserts = true;
1372  tensor = chain;
1373  }
1374  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1375  } else {
1376  // To rematerialize a non-annotated tensor, simply load it
1377  // from the bufferized value.
1378  Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1379  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1380  }
1381 }
1382 
1383 //===----------------------------------------------------------------------===//
1384 // Sparsifier rewriting methods.
1385 //===----------------------------------------------------------------------===//
1386 
1387 namespace {
1388 
1389 /// Sparse rewriting rule for a generic Linalg operation.
1390 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1391 public:
1392  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1393  : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1394 
1395  LogicalResult matchAndRewrite(linalg::GenericOp op,
1396  PatternRewriter &rewriter) const override {
1397  // Only accept single output operations with pure tensor semantics.
1398  if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1399  return failure();
1400 
1401  // Only accept trivial affine indices.
1402  if (hasNonTrivialAffineOnSparseOut(op))
1403  return failure();
1404 
1405  // Only accept scheduled loops.
1406  if (!op->hasAttr("sorted")) {
1407  return rewriter.notifyMatchFailure(
1408  op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1409  "before sparsification.");
1410  }
1411 
1412  // Must have been demapped as well if the generic op is sorted.
1413  assert(!hasAnyNonIdentityOperandsOrResults(op));
1414 
1415  // Sets up a code generation environment.
1416  const unsigned numTensors = op->getNumOperands();
1417  const unsigned numLoops = op.getNumLoops();
1418  bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1419  // If we have an indexing map like (d0) -> (0, d0), there might be more
1420  // levels than loops because of the constant index, which means we cannot
1421  // use numLoops as the upper bound for the ranks of all tensors.
1422  // TODO: Constant indices are currently not supported on sparse tensors, but
1423  // are allowed in non-annotated dense tensors. Supporting them would also be
1424  // required for sparse tensor slice rank reduction.
1425  Level maxLvlRank = 0;
1426  for (auto operand : op.getOperands()) {
1427  if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1428  maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1429  }
1430  }
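      // Worked example (illustrative, not from the original source): with an
      // indexing map (d0) -> (0, d0), the kernel has a single loop
      // (numLoops == 1) while the mapped operand has two levels, so its level
      // rank is 2; maxLvlRank then becomes 2 and bounds the per-tensor level
      // ranks handed to the codegen environment below, instead of numLoops.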
1431 
1432  // Detects sparse annotations and translates the per-level sparsity
1433  // information for all tensors to loop indices in the kernel.
1434  CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1435  if (!findSparseAnnotations(env, needIdxRed))
1436  return failure();
1437 
1438  // Only standard reduction operations (add, sub, or, xor) that can be
1439  // sparsified by merely reducing the stored values are admissible. More
1440  // elaborate reduction operations (such as mul, and, min, max) would also
1441  // need to know whether implicit zeros occur. They can still be
1442  // implemented with a custom reduction operation, which is accepted here too.
1443  if (op.getNumReductionLoops() > 0) {
1444  Operation *yield = op.getRegion().front().getTerminator();
1445  assert(isa<linalg::YieldOp>(yield));
1446  Operation *redop = yield->getOperand(0).getDefiningOp();
1447  if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1448  !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1449  !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1450  !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1451  !isa<ReduceOp>(redop)) {
1452  return failure();
1453  }
1454  }
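      // Worked example (illustrative, not from the original source): for a
      // sparse vector with stored values {2, 3} and two implicit zeros, an add
      // reduction over just the stored values gives 2 + 3 = 5, identical to
      // the dense result 2 + 3 + 0 + 0. A mul reduction over stored values
      // gives 2 * 3 = 6, whereas the dense result is 2 * 3 * 0 * 0 = 0, so
      // mul/and/min/max must account for implicit zeros, e.g. via the custom
      // ReduceOp accepted above.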
1455 
1456  // Constructs the tensor expression tree from `op`; returns failure if the
1457  // tree cannot be built or the tensor expression is inadmissible.
1458  if (failed(env.initTensorExp()))
1459  return failure();
1460 
1461  // Recursively generates code if admissible.
1462  env.startEmit(options.sparseEmitStrategy);
1463  genBuffers(env, rewriter);
1464  // TODO: Constant affine expressions should be handled differently when using
1465  // slice-based codegen; it does not matter now because we already reject
1466  // constant expressions at an earlier stage.
1467  genInitConstantDenseAddress(env, rewriter);
1468  genStmt(env, rewriter, env.getExprId(), 0);
1469  genResult(env, rewriter);
1470  return success();
1471  }
1472 
1473 private:
1474  /// Options to control sparse code generation.
1475  const SparsificationOptions options;
1476 };
1477 
1478 } // namespace
1479 
1480 /// Populates the given patterns list with rewriting rules required for
1481 /// the sparsification of linear algebra operations.
1482 void mlir::populateSparsificationPatterns(RewritePatternSet &patterns,
1483                                           const SparsificationOptions &options) {
1484  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1485 }
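A usage sketch follows (assumptions: the input has already been run through
--sparse-reinterpret-map so the linalg.generic ops carry the "sorted"
attribute, and applyPatternsGreedily from
mlir/Transforms/GreedyPatternRewriteDriver.h is used as the greedy driver
entry point). It shows how a pass would wire up the rewriting rule registered
above.

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Runs sparsification over `root` with default options; a real pass would
  // forward its own SparsificationOptions (parallelization and emit strategy).
  static void runSparsification(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    mlir::SparsificationOptions options; // default strategies
    mlir::populateSparsificationPatterns(patterns, options);
    (void)mlir::applyPatternsGreedily(root, std::move(patterns));
  }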