MLIR  19.0.0git
Sparsification.cpp
Go to the documentation of this file.
1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenEnv.h"
14 #include "Utils/CodegenUtils.h"
15 #include "Utils/LoopEmitter.h"
16 
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TensorEncoding.h"
36 #include "llvm/ADT/SmallBitVector.h"
37 
38 #include <optional>
39 
40 using namespace mlir;
41 using namespace mlir::sparse_tensor;
42 
43 //===----------------------------------------------------------------------===//
44 // Sparsifier analysis methods.
45 //===----------------------------------------------------------------------===//
46 
47 /// Returns true iff affine expression is invariant. Sets the
48 /// parameter `isCurrentLoop` when expression just became invariant.
49 static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
50  switch (a.getKind()) {
51  case AffineExprKind::DimId: {
52  const LoopId i = cast<AffineDimExpr>(a).getPosition();
53  if (i + 1 == curr) {
54  isCurrentLoop = true;
55  return true; // becomes invariant at current loop
56  }
57  return i < curr; // invariant when already generated
58  }
60  case AffineExprKind::Mul: {
61  auto binOp = cast<AffineBinaryOpExpr>(a);
62  return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
63  isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
64  }
65  default: {
66  assert(isa<AffineConstantExpr>(a));
67  return true;
68  }
69  }
70 }
71 
72 /// Helper method to inspect affine expressions. Rejects cases where the
73 /// same index is used more than once. Also rejects compound affine
74 /// expressions in sparse dimensions.
75 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
76  LevelType lt, bool setLvlFormat = true) {
77  switch (a.getKind()) {
78  case AffineExprKind::DimId: {
79  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
80  if (!isUndefLT(merger.getLvlType(tid, idx)))
81  return false; // used more than once
82  if (setLvlFormat)
83  merger.setLevelAndType(tid, idx, lvl, lt);
84  return true;
85  }
89  assert(lt.hasDenseSemantic());
90  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
91  // We do not set dim level format for affine expression like d0 + d1 on
92  // either loop index at d0 or d1. We continue the recursion merely to
93  // check whether current affine is admissible or not.
94  return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
95  findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
96  }
97  // Falls through when it is a constant Affine
98  return true;
99  }
100  default:
101  return false;
102  }
103 }
104 
105 /// Helper method to inspect affine expressions for index variable reduction
106 /// based codegen. It finds the dependent index set for all tensor levels in the
107 /// current expression we are generating.
108 ///
109 /// For example, when handling A[i+j][j+k], we build the two way mapping in
110 /// merger between (tensor, level) pairs and their dependent index variable set:
111 /// A_0 <=> [i, j] and A_1 <=> [j, k]
112 ///
113 /// It rejects cases (returns false)
114 /// 1st, when the same index is used more than once, e.g., A[i+j][i]
115 /// 2nd, when multiplication is used in the non-trivial index expression.
116 /// 3rd, when a constant operand is used in the non-trivial index expression.
117 ///
118 /// TODO: constant should be easy to handle.
119 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
120  AffineExpr a, LevelType lt, bool isSubExp = false,
121  int64_t coefficient = 1) {
122  switch (a.getKind()) {
123  case AffineExprKind::DimId: {
124  // Only allow positive coefficients on AffineDimExpr.
125  if (coefficient <= 0)
126  return false;
127 
128  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
129  if (!isUndefLT(merger.getLvlType(tensor, idx)))
130  return false; // used more than once, e.g., A[i][i]
131 
132  // TODO: Generalizes the following two cases. A[i] (with trivial index
133  // expression) can be treated as a special affine index expression. We do
134  // not necessarily need to differentiate them.
135  if (!isSubExp) {
136  assert(coefficient == 1);
137  merger.setLevelAndType(tensor, idx, lvl, lt);
138  }
139 
140  if (isSubExp) {
141  // The current loops appears in more than one affine expressions on the
142  // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
143  // used twice.
144  if (merger.hasDependentLvl(idx, tensor)) {
145  // TODO: This can be supported by coiterate slices if the loop idx is
146  // appeared on affine index for different tensor, or take slice on
147  // multiple dimensions when it is on the same tensor.
148  // E.g.,
149  // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
150  // d0_1 = getNextSliceOffset t0 along lvl0
151  // d0_2 = getNextSliceOffset t1 along lvl0
152  // if d0_1 == d0_2 then d0 = d0_1 = d0_1
153  // else increase min(d0_1, d0_2).
154  return false;
155  }
156  merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
157  }
158  return true;
159  }
161  case AffineExprKind::Mul: {
162  // TODO: Support index expression like `2 * d0`, we now only support more
163  // complicated cases like `2 * d0 + d1`.
164  if (!isSubExp)
165  return false;
166 
167  // TODO: Support Constant AffineExp for slice-based codegen
168  if (isa<AffineConstantExpr>(a))
169  llvm_unreachable("Not yet implemented");
170 
171  auto binOp = cast<AffineBinaryOpExpr>(a);
172  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
173  if (isa<AffineConstantExpr>(rhs))
174  std::swap(lhs, rhs);
175  // Must be in form of `constant * d`.
176  assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
177  int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
178  return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
179  }
180  case AffineExprKind::Add: {
181  auto binOp = cast<AffineBinaryOpExpr>(a);
182  return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
183  findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
184  }
185  default:
186  return false;
187  }
188 }
189 
190 /// Gets the total number of compound affine expressions in the
191 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
192 ///
193 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
194 ///
195 /// Returns 1 (because the first level is compressed and its corresponding
196 /// indexing-expression is `d0 + d1`)
198  Value tensor) {
199  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
200  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
201  // However, we don't need to handle `StorageSpecifierType`, so we
202  // can use `SparseTensorType` once we guard against non-tensors.
203  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
204  if (!rtp)
205  return 0;
206  const SparseTensorType stt(rtp);
207 
208  const Level lvlRank = stt.getLvlRank();
209  const auto exprs = map.getResults();
210  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
211  "AffineMap does not have dimension-rank many results");
212  unsigned num = 0;
213  for (Level l = 0; l < lvlRank; l++) {
214  if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
215  num++;
216  }
217  return num;
218 }
219 
220 /// Gets the total number of sparse levels with compound affine
221 /// expressions, summed over all operands of the `GenericOp`.
222 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
223  unsigned num = 0;
224  for (OpOperand &t : op->getOpOperands())
225  num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
226  t.get());
227  return num;
228 }
229 
230 // Returns true iff output has nontrivial affine indices.
231 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
232  OpOperand *out = op.getDpsInitOperand(0);
233  if (getSparseTensorType(out->get()).isAllDense())
234  return false;
235  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
236  out->get());
237 }
238 
239 /// Helper method to inspect sparse encodings in the tensor types.
240 /// Fills the per-dimension sparsity information for all tensors.
241 /// Returns true if the sparse annotations and affine subscript
242 /// expressions of all tensors are admissible. Returns false if
243 /// no annotations are found or inadmissible constructs occur.
244 /// We currently support two different ways to handle non-trivial index
245 /// expression on sparse tensors, and they accept different affine expressions.
246 /// When using dependent index reducton-based approach, it currently only
247 /// supports affine addition index expression.
248 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
249  bool annotated = false;
250  for (OpOperand &t : env.op()->getOpOperands()) {
251  const TensorId tid = env.makeTensorId(t.getOperandNumber());
252  const auto map = env.op().getMatchingIndexingMap(&t);
253  const auto enc = getSparseTensorEncoding(t.get().getType());
254  if (enc)
255  annotated = true;
256  const Level lvlRank = map.getNumResults();
257  assert(!enc || lvlRank == enc.getLvlRank());
258  assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
259  // We only need to do index reduction if there is at least one
260  // non-trivial index expression on sparse levels. If all non-trivial
261  // index expression is on dense levels, we can efficiently rely on
262  // the random access to locate the element.
263  bool needIdxReduc =
264  enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
265  // If then current tensor being inspected requires affine index, it need
266  // to be sliced.
267  for (Level l = 0; l < lvlRank; l++) {
268  const AffineExpr a = map.getResult(l);
269  const LevelType lt = enc.getLvlType(l);
270  if (idxReducBased && needIdxReduc) {
271  if (!findDepIdxSet(env.merger(), tid, l, a, lt))
272  return false; // inadmissible affine expression
273  } else {
274  if (!findAffine(env.merger(), tid, l, a, lt))
275  return false; // inadmissible affine expression
276  }
277  }
278  }
279  return annotated;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Sparsifier synthesis methods (statements and expressions).
284 //===----------------------------------------------------------------------===//
285 
/// Local bufferization of all dense and sparse data structures.
///
/// Computes the loop ranges of the generic op and passes two callbacks
/// (together with `builder` and `loc`) to a call whose head is not visible in
/// this listing:
///  - an output callback that prepares the buffer of the dense output tensor
///    (zero-filling it unless the output is an init tensor), and
///  - a size callback that returns the size of the loop range for level `l`
///    as a `Value`.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Exactly one DPS init (output) operand in addition to the inputs.
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  // NOTE(review): the callee that receives the argument list below is missing
  // from this listing; confirm against the full source.
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are written
      /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
      /// to all zeroes and only nonzero values are computed and written out.
      /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
      /// for the updates and no assumption on the original contents of the
      /// output buffer is necessary.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the tensor
        // that appears in the outs() clause. For updates, this has the
        // advantage that only the nonzero values are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we are
        // forced to zero out the buffer to enforce the assumption above, which
        // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
        // O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          // Zero-fill the incoming buffer in place with a `linalg.fill`.
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        // The incoming buffer itself is used as the output buffer.
        return init;
      },
      // Size callback: the size of loop range `l`, materialized as a Value
      // (`getValueOrCreateConstantIndexOp`, presumably creating an index
      // constant when the size is static).
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}
335 
336 /// Generates index for load/store on sparse tensor.
337 static Value genIndex(CodegenEnv &env, OpOperand *t) {
338  const auto map = env.op().getMatchingIndexingMap(t);
339  const auto stt = getSparseTensorType(t->get());
340  const Level lvlRank = stt.getLvlRank();
341  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
342  const AffineExpr a = map.getResult(lvlRank - 1);
343  assert(a.getKind() == AffineExprKind::DimId);
344  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
345  return env.getLoopVar(idx);
346 }
347 
348 /// Generates subscript for load/store on a dense or sparse tensor.
349 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
350  SmallVectorImpl<Value> &args) {
351  const Location loc = env.op().getLoc();
352  const TensorId tid = env.makeTensorId(t->getOperandNumber());
353  const auto map = env.op().getMatchingIndexingMap(t);
354  const auto stt = getSparseTensorType(t->get());
355  if (stt.hasEncoding()) {
356  // For sparse tensors we only push the last-level's position onto `args`.
357  const auto pos = env.emitter().getValPosits(tid);
358  assert(!pos.empty());
359  args.append(pos);
360  } else {
361  // For dense tensors we push all level's coordinates onto `args`.
362  const Level lvlRank = stt.getLvlRank();
363  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
364  for (Level l = 0; l < lvlRank; l++) {
365  const auto lvlExpr = map.getResult(l);
366  const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
367  args.push_back(lvlCrd);
368  }
369  }
370  return env.emitter().getValBuffer()[tid];
371 }
372 
373 /// Generates insertion code to implement dynamic tensor load.
375  OpOperand *t) {
376  linalg::GenericOp op = env.op();
377  Location loc = op.getLoc();
378  // Direct lexicographic coordinate order, tensor loads as zero.
379  if (!env.isExpand()) {
380  Type tp = getElementTypeOrSelf(t->get().getType());
381  return constantZero(builder, loc, tp);
382  }
383  // Load from expanded access pattern.
384  Value index = genIndex(env, t);
385  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
386 }
387 
388 /// Generates insertion code to implement dynamic tensor load for reduction.
390  OpOperand *t) {
391  linalg::GenericOp op = env.op();
392  Location loc = op.getLoc();
393  Value identity = env.getCustomRedId();
394  // Direct lexicographic coordinate order, tensor loads as identity.
395  if (!env.isExpand())
396  return identity;
397  // Load from expanded access pattern if filled, identity otherwise.
398  Value values = env.getExpandValues();
399  Value filled = env.getExpandFilled();
400  Value index = genIndex(env, t);
401  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
402  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
403  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
404 }
405 
406 static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
407  Value sparseOut, ValueRange ivs, Value v) {
408  scf::IfOp condInsert =
409  builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
410  // True branch.
411  builder.setInsertionPointToStart(condInsert.thenBlock());
412  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
413  builder.create<scf::YieldOp>(loc, res);
414  // False branch.
415  builder.setInsertionPointToStart(condInsert.elseBlock());
416  builder.create<scf::YieldOp>(loc, sparseOut);
417  // Value assignment.
418  builder.setInsertionPointAfter(condInsert);
419  return condInsert.getResult(0);
420 }
421 
422 /// Generates insertion code to implement dynamic tensor store.
423 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
424  Value rhs) {
425  linalg::GenericOp op = env.op();
426  Location loc = op.getLoc();
427  // Direct insertion in lexicographic coordinate order.
428  if (!env.isExpand()) {
429  const LoopId numLoops = op.getRank(t);
430  // Retrieves the first `numLoop` induction variables.
431  SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
432  env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
433  Value chain = env.getInsertionChain();
434  if (env.isValidLexInsert()) {
435  // Generates runtime check for a valid lex during reduction,
436  // to avoid inserting the identity value for empty reductions.
437  // if (validLexInsert) then
438  // insert(rhs) into chain
439  // return updated chain
440  // else
441  // return unmodified chain
442  Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
443  chain, ivs, rhs);
444  env.updateInsertionChain(out);
445  } else {
446  Value sparseOut;
447  if (!hasAnySparseType(env.op().getInputs().getTypes())) {
448  // This is an all-dense -> sparse kernel, test rhs != 0 before
449  // insertion.
450  Value nz = genIsNonzero(builder, loc, rhs);
451  sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
452  } else {
453  sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
454  }
455  // Generates regular insertion chain.
456  env.updateInsertionChain(sparseOut);
457  }
458  return;
459  }
460  // Generates insertion code along expanded access pattern.
461  // if (!expFilled[i]) then
462  // expFilled[i] = true
463  // expAdded[inserts++] = i
464  // endif
465  // values[i] = rhs
466  Value values = env.getExpandValues();
467  Value filled = env.getExpandFilled();
468  Value added = env.getExpandAdded();
469  Value count = env.getExpandCount();
470  Value index = genIndex(env, t);
471  Value fval = constantI1(builder, loc, false);
472  Value tval = constantI1(builder, loc, true);
473  // If statement.
474  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
475  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
476  isFilled, fval);
477  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
478  /*else=*/true);
479  // True branch.
480  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
481  builder.create<memref::StoreOp>(loc, tval, filled, index);
482  builder.create<memref::StoreOp>(loc, index, added, count);
483  Value one = constantIndex(builder, loc, 1);
484  Value add = builder.create<arith::AddIOp>(loc, count, one);
485  builder.create<scf::YieldOp>(loc, add);
486  // False branch.
487  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
488  builder.create<scf::YieldOp>(loc, count);
489  builder.setInsertionPointAfter(ifOp);
490  // Value assignment.
491  env.updateExpandCount(ifOp.getResult(0));
492  builder.create<memref::StoreOp>(loc, rhs, values, index);
493 }
494 
495 /// Generates a load on a dense or sparse tensor.
496 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
497  // Test if the load was hoisted to a higher loop nest.
498  Value val = env.exp(exp).val;
499  if (val)
500  return val;
501  // Get tensor operand.
502  linalg::GenericOp op = env.op();
503  Location loc = op.getLoc();
504  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
505  // Fold binary-valued tensor into explicit value.
506  const auto stt = getSparseTensorType(t->get());
507  if (auto explVal = stt.getExplicitVal())
508  return genValFromAttr(builder, loc, explVal);
509  // Load during insertion.
510  if (env.isSparseOutput(t)) {
511  if (env.isCustomReduc())
512  return genInsertionLoadReduce(env, builder, t);
513  return genInsertionLoad(env, builder, t);
514  }
515  // Actual load.
516  SmallVector<Value> args;
517  Value ptr = genSubscript(env, builder, t, args);
518  return builder.create<memref::LoadOp>(loc, ptr, args);
519 }
520 
521 /// Generates a store on a dense or sparse tensor.
522 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
523  Value rhs) {
524  // Only unary and binary are allowed to return an uninitialized rhs
525  // to indicate missing output. Or otherwise a custom reduction that
526  // received no value to accumulate.
527  if (!rhs) {
528  assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
529  env.exp(exp).kind == TensorExp::Kind::kBinary ||
530  env.exp(exp).kind == TensorExp::Kind::kReduce);
531  return;
532  }
533  // Test if this is a scalarized reduction.
534  if (env.isReduc()) {
535  env.updateReduc(rhs);
536  return;
537  }
538  // Regular store.
539  linalg::GenericOp op = env.op();
540  Location loc = op.getLoc();
541  OpOperand *t = op.getDpsInitOperand(0);
542  if (!env.isSparseOutput(t)) {
543  SmallVector<Value> args;
544  Value ptr = genSubscript(env, builder, t, args);
545  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
546  return;
547  }
548  // Store during sparse insertion.
549  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
550  genInsertionStore(env, builder, t, rhs);
551  return;
552  }
553  // Select operation insertion.
554  Value chain = env.getInsertionChain();
555  scf::IfOp ifOp =
556  builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
557  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
558  // Existing value was preserved to be used here.
559  assert(env.exp(exp).val);
560  Value v0 = env.exp(exp).val;
561  genInsertionStore(env, builder, t, v0);
562  env.merger().clearExprValue(exp);
563  // Yield modified insertion chain along true branch.
564  Value mchain = env.getInsertionChain();
565  builder.create<scf::YieldOp>(op.getLoc(), mchain);
566  // Yield original insertion chain along false branch.
567  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
568  builder.create<scf::YieldOp>(loc, chain);
569  // Done with if statement.
570  env.updateInsertionChain(ifOp->getResult(0));
571  builder.setInsertionPointAfter(ifOp);
572 }
573 
574 /// Generates an invariant value.
575 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
576  return env.exp(exp).val;
577 }
578 
579 /// Semi-ring branches are simply inlined by the sparsifier. Prior
580 /// analysis has verified that all computations are "local" to the inlined
581 /// branch or otherwise invariantly defined outside the loop nest, with the
582 /// exception of index computations, which need to be relinked to actual
583 /// inlined cloned code.
584 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
585  Value e) {
586  if (auto arg = dyn_cast<BlockArgument>(e)) {
587  // Direct arguments of the original linalg op must be converted
588  // into dense tensor loads. Note that we should not encounter
589  // anything else. This needs to be verified by semi-ring ops.
590  linalg::GenericOp op = env.op();
591  if (arg.getOwner()->getParentOp() == op) {
592  const TensorId tid = env.makeTensorId(arg.getArgNumber());
593  OpOperand *t = &op->getOpOperand(tid);
594  assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
595  SmallVector<Value> args;
596  Value ptr = genSubscript(env, rewriter, t, args);
597  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
598  }
599  } else if (Operation *def = e.getDefiningOp()) {
600  // Handle index computation.
601  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
602  return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
603  // When still defined in new body, recurse into operands.
604  if (def->getBlock() == block) {
605  rewriter.setInsertionPoint(def);
606  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
607  rewriter.modifyOpInPlace(def, [&]() {
608  def->setOperand(
609  i, relinkBranch(env, rewriter, block, def->getOperand(i)));
610  });
611  }
612  }
613  }
614  return e;
615 }
616 
/// Recursively generates tensor expression.
///
/// Leaves are handled directly:
///  - tensor: via `genTensorLoad`,
///  - invariant: via `genInvariantValue`,
///  - loop variable: the loop's induction value.
/// Other kinds first generate both children, then build the expression
/// through the merger. Custom reductions (`kReduce`) are bracketed by
/// start/end calls on the environment, and may yield a null value when an
/// operand is missing. For `kSelect`, the first operand is preserved on the
/// expression for later use (see `genTensorStore`).
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  // NOTE(review): the condition guarding this early return is missing from
  // this listing (presumably an invalid-expression check); confirm against
  // the full source.
  return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
  // based on the type of the other operand.
  // NOTE(review): the `if`/`else if` headers that select which child is the
  // synthetic zero are missing from this listing; confirm against the full
  // source.
  Value v0, v1;
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    // Values produced inside inlined semi-ring branches must be relinked to
    // the generated code; the insertion guard restores the rewriter's
    // insertion point afterwards.
    // NOTE(review): part of the following condition is missing from this
    // listing; only its tail (kReduce / kSelect) is visible.
    if (ee &&
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}
675 
/// Hoists loop invariant tensor loads for which indices have been exhausted.
///
/// Walks the expression tree rooted at `exp`. `isStart` selects between the
/// start action and the end action for loop `curr`. For a tensor leaf, all of
/// whose subscripts became invariant exactly at `curr` (see
/// `isInvariantAffine`):
///  - if the tensor is the output, a scalarized reduction is started or
///    ended, together with the valid-lex tracking when the output is sparse;
///  - otherwise a hoisted load is recorded on, or cleared from, the
///    expression.
/// Other expressions (except invariants, loop variables, and synthetic
/// zeros) recurse into both children; custom reductions are bracketed by
/// start/end calls on the environment.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  // NOTE(review): the condition guarding this early return is missing from
  // this listing; confirm against the full source.
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        // A custom reduction starts from its identity value (only once);
        // a regular reduction starts from the current output value.
        if (env.isCustomReduc()) {
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        // NOTE(review): the call head for the statement below is missing
        // from this listing; only its trailing argument (a `false` i1
        // constant) is visible.
        if (env.hasSparseOutput())
              constantI1(builder, env.op().getLoc(), false));
      } else {
        // Write back the reduction value, unless this is a custom reduction
        // whose scalarized reduction is no longer active.
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}
745 
746 /// Generates an expanded access pattern in innermost dimension.
747 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
748  bool isStart) {
749  linalg::GenericOp op = env.op();
750  OpOperand *lhs = op.getDpsInitOperand(0);
751  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
752  return; // not needed at current level
753  assert(!env.isReduc());
754  // Generate start or end of an expanded access pattern. Note that because
755  // an expansion does not rely on the ongoing contents of the sparse storage
756  // scheme, we can use the original tensor as incoming SSA value (which
757  // simplifies codegen a bit). If expansion on the actual contents is ever
758  // needed, we will need to use the SSA value in the insertion chain instead.
759  Value tensor = lhs->get();
760  Location loc = op.getLoc();
761  if (isStart) {
762  auto dynShape = {ShapedType::kDynamic};
763  Type etp = cast<ShapedType>(tensor.getType()).getElementType();
764  Type t1 = MemRefType::get(dynShape, etp);
765  Type t2 = MemRefType::get(dynShape, builder.getI1Type());
766  Type t3 = MemRefType::get(dynShape, builder.getIndexType());
767  Type t4 = builder.getIndexType();
768  auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
769  assert(r.getNumResults() == 4);
770  env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
771  r.getResult(3));
772  } else {
773  SmallVector<Value> indices;
774  for (LoopId i = 0; i < curr; i++)
775  indices.push_back(env.emitter().getLoopIV(i));
776  Value values = env.getExpandValues();
777  Value filled = env.getExpandFilled();
778  Value added = env.getExpandAdded();
779  Value count = env.getExpandCount();
780  Value chain = env.getInsertionChain();
781  Value compress = builder.create<CompressOp>(loc, values, filled, added,
782  count, chain, indices);
783  env.updateInsertionChain(compress);
784  env.endExpand();
785  }
786 }
787 
/// Returns parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" is a candidate. Whether it is actually
/// converted to a parallel operation depends on the requested strategy.
///
/// Parallelization is always refused when the kernel has a sparse output or
/// an expanded access pattern is active. Otherwise the answer depends on the
/// configured `parallelizationStrategy`, together with whether the loop is
/// the outermost one (`isOuter`) and whether it iterates over a sparse level
/// (`isSparse`).
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
  // Reject parallelization of sparse output.
  if (env.hasSparseOutput())
    return false;
  // Parallel loops on tensor expansion can cause data races.
  if (env.isExpand())
    return false;
  // Inspect strategy.
  // NOTE(review): the case labels of this switch are missing from this
  // listing; each return below presumably belongs to one strategy value,
  // from the most to the least restrictive. Confirm against the full source.
  switch (env.options().parallelizationStrategy) {
    return false;                // never parallel
    return isOuter && !isSparse; // outermost loop over dense levels only
    return isOuter;              // outermost loop, any storage
    return !isSparse;            // any loop over dense levels only
    return true;                 // any loop, any storage
  }
  llvm_unreachable("unexpected parallelization strategy");
}
813 
814 /// Whether or not the current loop being generated should be parallized (if
815 /// possible) according to the configuration.
816 static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
817  ArrayRef<TensorLevel> tidLvls) {
818  linalg::GenericOp op = env.op();
819  auto iteratorTypes = op.getIteratorTypesArray();
820  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
821  // Queries the LT based on the tensor and loop id, as requested by
822  // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
823  // should be consistent with the LT indexed by <TensorId, Level>.
824  const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
825  return lt.hasSparseSemantic();
826  });
827  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
828 }
829 
830 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
831 /// can either be a for loop or while loop depending on whether there is at most
832 /// one sparse level in the list.
/// The loop is created by the loop emitter
/// (`enterCoIterationOverTensorsAtLvls`) from inside `env.genLoopBoundary`.
/// `tryParallel` and `needsUniv` are forwarded to the emitter unchanged.
/// Returns the generated loop operation, which is asserted to be non-null.
/// NOTE(review): the first signature line (833) is missing from this copy.
834  ArrayRef<TensorLevel> tidLvls,
835  bool tryParallel, bool needsUniv) {
// NOTE(review): the result of genLoopBoundary is dereferenced unconditionally,
// so it is presumably always engaged on loop entry; confirm in CodegenEnv.
836  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
837  // Let the emitter build the for/while-loop; `reduc` is forwarded to it.
838  return env.emitter().enterCoIterationOverTensorsAtLvls(
839  builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
840  });
841  assert(loop);
842  return loop;
843 }
844 
845 /// Generates a for-loop or a while-loop, depending on whether it implements
846 /// singleton iteration or co-iteration over the given conjunction.
847 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
848  bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
849  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
850  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
851 }
852 
853 /// Generates the induction structure for a while-loop.
854 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
855  bool needsUniv) {
856  Location loc = env.op().getLoc();
857  // Finalize each else branch of all if statements.
858  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
859  while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
860  builder.getInsertionBlock()->getParentOp())) {
861  // Break on IfOp for slicing filtering.
862  if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
863  StringAttr::get(ifOp->getContext(), "slice"))
864  break;
865 
866  unsigned y = 0;
867  SmallVector<Value> yields;
868  if (env.isReduc()) {
869  yields.push_back(env.getReduc());
870  env.updateReduc(ifOp.getResult(y++));
871  if (env.isValidLexInsert()) {
872  yields.push_back(env.getValidLexInsert());
873  env.updateValidLexInsert(ifOp.getResult(y++));
874  }
875  }
876  if (env.isExpand()) {
877  yields.push_back(env.getExpandCount());
878  env.updateExpandCount(ifOp->getResult(y++));
879  }
880  if (env.getInsertionChain()) {
881  yields.push_back(env.getInsertionChain());
882  env.updateInsertionChain(ifOp->getResult(y++));
883  }
884  assert(y == yields.size());
885  builder.create<scf::YieldOp>(loc, yields);
886  builder.setInsertionPointAfter(ifOp);
887  }
888  }
889  // No need to set the insertion point here as LoopEmitter keeps track of the
890  // basic block where scf::Yield should be inserted.
891 }
892 
893 /// Generates a single if-statement within a while-loop.
/// The condition is the conjunction, over the simple tensor-loop conditions of
/// lattice point `p`, of the following clause per level:
///   - `crd == loopVar(curr)` for levels with sparse semantics,
///   - a constant `true` for dense or undefined levels.
/// For index-reduction conditions, the level type is recomputed from the input
/// tensor type. The created `scf.if` has an else-region and one result per
/// active piece of loop-carried state, in this order:
///   - the reduction value,
///   - the valid-lex-insert flag,
///   - the expansion count (index type),
///   - the insertion chain.
/// The builder is left at the start of the then-region. Returns the `scf.if`.
894 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
895  LatPointId p) {
896  Location loc = env.op().getLoc();
897  SmallVector<Type> types;
898  Value cond;
// NOTE(review): the line that opens the iteration over the tensor-loop
// conditions of `p` (original line 899, which receives the lambda below) is
// missing from this copy.
900  p, /*simple=*/true,
901  [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
902  bool isIdxRed) {
903  if (isIdxRed) {
904  // Since there is no 1:1 mapping from loop to level (multiple loops
905  // are required to resolve one level with non-trivial index
906  // expression), we need to reconstruct the tensor level types if this
907  // loop requires index reduction condition.
908  assert(lvl.has_value() && isUndefLT(lt));
909  auto stt = getSparseTensorType(env.op().getInputs()[tid]);
910  lt = stt.getLvlType(*lvl);
911  }
912  assert(curr == env.merger().loop(b));
913  Value clause;
914  if (lt.hasSparseSemantic()) {
915  assert(lvl.has_value());
916  const Value crd = env.emitter().getCoord(tid, *lvl);
917  const Value lvar = env.getLoopVar(curr);
918  clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
919  crd, lvar);
920  } else {
921  assert(lt.hasDenseSemantic() || isUndefLT(lt));
922  clause = constantI1(builder, loc, true);
923  }
924  cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
925  });
// Result types mirror the yield order used by endIf/finalizeWhileOp.
926  if (env.isReduc()) {
927  types.push_back(env.getReduc().getType());
928  if (env.isValidLexInsert())
929  types.push_back(env.getValidLexInsert().getType());
930  }
931  if (env.isExpand())
932  types.push_back(builder.getIndexType());
933  if (env.getInsertionChain())
934  types.push_back(env.getInsertionChain().getType());
// NOTE(review): `cond` stays null if `p` has no simple conditions; presumably
// that never happens for a co-iterating while-loop. Confirm.
935  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
936  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
937  return ifOp;
938 }
939 
940 /// Generates end of true branch of if-statement within a while-loop.
941 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
942  Value redInput, Value cntInput, Value insInput,
943  Value validIns) {
944  SmallVector<Value> operands;
945  if (env.isReduc()) {
946  operands.push_back(env.getReduc());
947  env.updateReduc(redInput);
948  if (env.isValidLexInsert()) {
949  // Any overlapping indices during a reduction creates a valid lex insert.
950  operands.push_back(constantI1(builder, env.op().getLoc(), true));
951  env.updateValidLexInsert(validIns);
952  }
953  }
954  if (env.isExpand()) {
955  operands.push_back(env.getExpandCount());
956  env.updateExpandCount(cntInput);
957  }
958  if (env.getInsertionChain()) {
959  operands.push_back(env.getInsertionChain());
960  env.updateInsertionChain(insInput);
961  }
962  if (!operands.empty())
963  builder.create<scf::YieldOp>(env.op().getLoc(), operands);
964  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
965 }
966 
967 //===----------------------------------------------------------------------===//
968 // Sparsifier synthesis methods (loop sequence).
969 //===----------------------------------------------------------------------===//
970 
/// Invokes `callback` once for every tensor level that takes part in lattice
/// point `li` at loop `curr`. The callback's affine-expression argument is
/// non-null only for dense levels whose compound affine index becomes
/// invariant exactly at `curr`; the address of such a level must be located
/// at this loop. The function also:
///   - reports the dense output level for all-dense annotated outputs,
///   - falls back to the synthetic tensor when no loop condition was found.
/// Returns true iff there is exactly one loop condition and none of the
/// conditions is on a non-unique level, i.e. a for-loop suffices.
/// NOTE(review): the signature line (971) and the line that opens the
/// iteration over the tensor-loop conditions of `li` (980) are missing from
/// this copy.
972  CodegenEnv &env, LatPointId li, LoopId curr,
973  llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
974  const BitVector &simple = env.lat(li).simple;
975  const TensorId outTid = env.merger().getOutTensorID();
976  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
977 
978  unsigned numloopCond = 0;
979  bool hasNonUnique = false;
981  li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
982  LevelType lt, bool isIdxReduc) {
983  if (simple[b]) {
984  if (isIdxReduc) {
985  callback(env.makeTensorLevel(tid, *lvl), nullptr);
986  numloopCond++;
987  return;
988  }
989  if (isUndefLT(lt)) {
990  // An undefined lt in the lattices, we probably mean to
991  // generate a dense loop according to the synthetic tensor (for
992  // invariants and sparse output tensor).
993  if (env.merger().getSynTensorID() == tid) {
994  // Coiterating with an invariant
995  // e.g., out = prod(in[i][j] op invariant);
996  // or a broadcast
997  // e.g., out[i][j] = in[i] (j is undef for input)
998  //
999  // The level of the synthetic tensor is the current loop depth;
1000  // the rank of the synthetic tensor equals to number of loops.
1001  assert(curr == env.getCurrentDepth());
1002  lvl = curr;
1003  } else if (!lvl) {
1004  // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1005  return;
1006  }
1007  }
1008  hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1009  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1010  numloopCond++;
1011  } else if (lt.hasDenseSemantic() || isIdxReduc) {
1012  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1013  } else {
1014  assert(isUndefLT(lt));
1015  linalg::GenericOp op = env.op();
1016  if (tid >= op.getNumDpsInputs())
1017  // We only handle affine expression on input tensors (for now).
1018  return;
1019  OpOperand *operand = &op->getOpOperand(tid);
1020  const auto stt = getSparseTensorType(operand->get());
1021  // Non-annotated dense tensors require no special handling.
1022  if (!stt.hasEncoding())
1023  return;
1024 
1025  ArrayRef<AffineExpr> affines =
1026  op.getMatchingIndexingMap(operand).getResults();
1027  const Level lvlRank = stt.getLvlRank();
1028  assert(affines.size() == static_cast<size_t>(lvlRank));
1029  for (Level l = 0; l < lvlRank; l++) {
1030  AffineExpr exp = affines[l];
1031  // Skip simple affine expression and non-dense levels (which
1032  // have their own filter loop).
// NOTE: this `lt` shadows the callback's `lt` parameter; it is the level type
// of level `l` of `stt`.
1033  LevelType lt = stt.getLvlType(l);
1034  if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1035  continue;
1036 
1037  // Constant affine expressions are handled by genConstantDenseAddressFromLevel.
1038  if (!isa<AffineConstantExpr>(exp)) {
1039  bool isCurrentLoop = false;
1040  assert(curr == env.getCurrentDepth());
1041  if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1042  isCurrentLoop) {
1043  // If the compound affine is invariant and we are right at the
1044  // level. We need to generate the address according to the
1045  // affine expression. This is also the best place we can do it
1046  // to avoid putting it inside inner loops.
1047  callback(env.makeTensorLevel(tid, l), exp);
1048  }
1049  }
1050  }
1051  }
1052  });
1053 
// NOTE(review): `*outLvl` is dereferenced below without a check. This
// presumably relies on a dense output level always existing whenever
// isDenseLT(env.lt(outTid, curr)) holds. Confirm.
1054  if (isDenseLT(env.lt(outTid, curr))) {
1055  auto stt = getSparseTensorType(env.op().getOutputs().front());
1056  // Note that we generate dense indices of the output tensor unconditionally,
1057  // since they may not appear in the lattice, but may be needed for
1058  // linearized env.
1059  // TODO: we should avoid introducing corner cases for all-dense sparse
1060  // tensors.
1061  if (stt.hasEncoding() && stt.isAllDense())
1062  callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1063  }
1064 
1065  if (numloopCond == 0) {
1066  // Corner cases where the loop bound is defined by an *unused* operand, in
1067  // this case, we just generate a dense "fake" loop by iterating over the
1068  // synthetic tensor.
1069  callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1070  numloopCond++;
1071  }
1072  // If there is exactly one loop condition and no condition is imposed on a
1073  // non-unique level, the loop can be generated by a for loop.
1074  return numloopCond == 1 && !hasNonUnique;
1075 }
1076 
1077 /// Starts a loop sequence at given level. Returns true if
1078 /// the universal loop index must be maintained at this level.
1079 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1080  LoopId curr, LatSetId lts) {
1081  assert(!env.getLoopVar(curr));
1082  // Emit invariants at this loop sequence level.
1083  genInvariants(env, builder, exp, curr, /*isStart=*/true);
1084  // Emit access pattern expansion for sparse tensor output.
1085  genExpand(env, builder, curr, /*isStart=*/true);
1086  // Emit further initialization at this loop sequence level.
1087  const LatPointId l0 = env.set(lts)[0];
1088 
1089  SmallVector<TensorLevel> tidLvls;
1090  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1091  // TODO: remove this! The same tensor level might be added for multiple
1092  // times due to the special handling for all-dense "sparse" output tensor
1093  // (see L1038).
1094  if (llvm::find(tidLvls, tl) != tidLvls.end())
1095  return;
1096  tidLvls.emplace_back(tl);
1097  });
1098 
1099  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1100 
1101  // Maintain the universal index only if it is actually
1102  // consumed by a subsequent lattice point.
1103  for (const LatPointId li : env.set(lts).drop_front())
1104  if (!env.merger().hasAnySparse(env.lat(li).simple))
1105  return true;
1106 
1107  return false;
1108 }
1109 
1110 // Generates dense affine address for encoding.
// For input tensor `tid` with a sparse encoding, this locates consecutive
// levels, starting at `startLvl`, whose level type has dense semantics and
// whose index expression is an affine constant. It stops at the first level
// that does not satisfy both conditions. Non-encoded inputs are left alone.
// NOTE(review): the signature line (1111) and the call line inside the loop
// (1129) are missing from this copy.
1112  OpBuilder &builder, TensorId tid,
1113  Level startLvl) {
1114  // TODO: Handle affine expression on output tensor.
1115  linalg::GenericOp op = env.op();
1116  assert(tid < op.getNumDpsInputs());
1117  OpOperand *input = op.getDpsInputOperands()[tid];
1118  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1119  const auto enc = getSparseTensorEncoding(input->get().getType());
1120  if (enc) {
1121  const Location loc = op.getLoc();
// NOTE(review): this local `tid` shadows the parameter of the same name; it
// is rebuilt from the operand number of `input`.
1122  const TensorId tid = env.makeTensorId(input->getOperandNumber());
1123  const Level lvlRank = enc.getLvlRank();
1124  assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1125  for (Level l = startLvl; l < lvlRank; l++) {
1126  AffineExpr lvlExpr = lvlExprs[l];
1127  if (enc.getLvlType(l).hasDenseSemantic() &&
1128  isa<AffineConstantExpr>(lvlExpr))
1130  builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1131  else
1132  return; // stop at the first level that is not dense with a constant index
1133  }
1134  }
1135 }
1136 
1137 // We can generate address for constant affine expression before any loops
1138 // starting from the first level as they do not depend on anything.
1139 // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1140 // levels can be determined before loops.
// Visits every DPS input tensor (outputs are not handled here), starting at
// level 0 of each one.
// NOTE(review): the signature line (1141) is missing from this copy.
1142  RewriterBase &rewriter) {
1143  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1144  genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1145 }
1146 
1147 /// Returns true if the lattice bit can be iterated by a for loop.
/// Splits the tensor levels reported by `getAllTidLvlsInLatPoints` for
/// lattice point `li` at loop `curr` into two groups:
///   - levels paired with a non-null affine expression are appended to
///     `affineTidLvls`,
///   - all other levels are appended to `tidLvls`.
/// NOTE(review): two signature lines (1148, 1150) are missing from this copy.
1149  CodegenEnv &env, LatPointId li, LoopId curr,
1151  SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1152  return getAllTidLvlsInLatPoints(env, li, curr,
1153  [&](TensorLevel tl, AffineExpr exp) {
1154  if (exp)
1155  affineTidLvls.emplace_back(tl, exp);
1156  else
1157  tidLvls.emplace_back(tl);
1158  });
1159 }
1160 
1161 /// Starts a single loop in current sequence.
/// The function proceeds in three steps:
///   1. Emits the for/while-loop for lattice point `li` at loop `curr`.
///   2. Locates the levels whose compound affine address just became
///      invariant.
///   3. For every non-output, non-synthetic tensor touched by the loop,
///      locates the following dense levels that have constant indices.
/// Returns the loop operation and whether it has a single condition.
1162 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1163  OpBuilder &builder, LoopId curr,
1164  LatPointId li, bool needsUniv) {
1165  // The set of tensors + lvls to generate loops on
1166  SmallVector<TensorLevel> tidLvls;
1167 
1168  // The set of dense tensors with non-trivial affine expression that just
1169  // becomes invariant and the address are generated at the current level.
// NOTE(review): the declaration of `affineTidLvls` (line 1170) is missing
// from this copy.
1171  bool isSingleCond =
1172  translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1173 
1174  // Emit the for/while-loop control.
1175  Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
1176  Location loc = env.op().getLoc();
1177  for (auto [tidLvl, exp] : affineTidLvls) {
1178  env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1179  }
1180 
1181  // Until now, we have entered every <tid, lvl> pair in tidLvls and
1182  // affineTidLvls. The addresses of the upcoming levels which are dependent
1183  // on constant affines expression may now be determined.
1184  auto allTidLvls =
1185  llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1186  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1187  if (tid != env.merger().getOutTensorID() &&
1188  tid != env.merger().getSynTensorID())
1189  genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1190  }
1191 
1192  return std::make_pair(loop, isSingleCond);
1193 }
1194 
1195 /// Ends a single loop in current sequence. Returns new values for needsUniv.
1196 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1197  LatPointId li, bool needsUniv, bool isSingleCond) {
1198  // Either a for-loop or a while-loop that iterates over a slice.
1199  if (isSingleCond) {
1200  // Any iteration creates a valid lex insert.
1201  if (env.isReduc() && env.isValidLexInsert())
1202  env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1203  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1204  // End a while-loop.
1205  finalizeWhileOp(env, rewriter, needsUniv);
1206  } else {
1207  needsUniv = false;
1208  }
1209  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1210  env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1211  return std::nullopt;
1212  });
1213  return needsUniv;
1214 }
1215 
1216 /// Ends a loop sequence at given level.
1217 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1218  unsigned at) {
1219  assert(!env.getLoopVar(at));
1220  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1221  // Unmark bookkeeping of invariants and loop index.
1222  genInvariants(env, builder, exp, at, /*isStart=*/false);
1223  // Finalize access pattern expansion for sparse tensor output.
1224  genExpand(env, builder, at, /*isStart=*/false);
1225 }
1226 
1227 /// Recursively generates code while computing iteration lattices in order
1228 /// to manage the complexity of implementing co-iteration over unions
1229 /// and intersections of sparse iterations spaces.
1230 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1231  LoopId curr) {
1232  assert(curr == env.getCurrentDepth());
1233 
1234  // At each leaf, assign remaining tensor (sub)expression to output tensor.
1235  if (curr == env.getLoopNum()) {
1236  Value rhs = genExp(env, rewriter, exp);
1237  genTensorStore(env, rewriter, exp, rhs);
1238  return;
1239  }
1240 
1241  // Construct iteration lattices for current loop index.
1242  const LatSetId lts =
1243  env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1244 
1245  // Start a loop sequence.
1246  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1247 
1248  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1249  // We cannot change this to `for (const LatPointId li : env.set(lts))`
1250  // because the loop body causes data-movement which invalidates
1251  // the iterator.
1252  const unsigned lsize = env.set(lts).size();
1253  for (unsigned i = 0; i < lsize; i++) {
1254  const LatPointId li = env.set(lts)[i];
1255  // Start a loop.
1256  auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);
1257 
1258  // Visit all lattices points with Li >= Lj to generate the
1259  // loop-body, possibly with if statements for coiteration.
1260  Value redInput = env.getReduc();
1261  Value cntInput = env.getExpandCount();
1262  Value insInput = env.getInsertionChain();
1263  Value validIns = env.getValidLexInsert();
1264  // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1265  // because the loop body causes data-movement which invalidates the
1266  // iterator.
1267  for (unsigned j = 0; j < lsize; j++) {
1268  const LatPointId lj = env.set(lts)[j];
1269  const ExprId ej = env.lat(lj).exp;
1270  if (li == lj || env.merger().latGT(li, lj)) {
1271  // Recurse into body of each branch.
1272  if (!isSingleCond) {
1273  scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1274  genStmt(env, rewriter, ej, curr + 1);
1275  endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1276  } else {
1277  genStmt(env, rewriter, ej, curr + 1);
1278  }
1279  }
1280  }
1281 
1282  // End a loop.
1283  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1284  }
1285 
1286  // End a loop sequence.
1287  endLoopSeq(env, rewriter, exp, curr);
1288  assert(curr == env.getCurrentDepth());
1289 }
1290 
1291 /// Converts the result computed by the sparse kernel into the required form.
1292 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1293  linalg::GenericOp op = env.op();
1294  OpOperand *lhs = op.getDpsInitOperand(0);
1295  Value tensor = lhs->get();
1296  Type resType = tensor.getType();
1297  if (getSparseTensorEncoding(resType)) {
1298  // The sparse tensor rematerializes from the original sparse tensor's
1299  // underlying sparse storage format. For an insertion chain, the
1300  // tensor materializes from the chain with 'hasInserts' enabled.
1301  bool hasInserts = false;
1302  if (Value chain = env.getInsertionChain()) {
1303  hasInserts = true;
1304  tensor = chain;
1305  }
1306  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1307  } else {
1308  // To rematerialize an non-annotated tensor, simply load it
1309  // from the bufferized value.
1310  Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1311  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1312  }
1313 }
1314 
1315 //===----------------------------------------------------------------------===//
1316 // Sparsifier rewriting methods.
1317 //===----------------------------------------------------------------------===//
1318 
1319 namespace {
1320 
1321 /// Sparse rewriting rule for generic Linalg operation.
/// Matches single-output `linalg.generic` ops with pure tensor semantics that
/// have already been scheduled (carry the "sorted" attribute). It rejects ops
/// with inadmissible annotations, reductions or tensor expressions.
/// Otherwise it rewrites the op into explicit loops over the sparse storage
/// via genBuffers, genInitConstantDenseAddress, genStmt and genResult.
1322 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1323 public:
1324  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1325  : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1326 
1327  LogicalResult matchAndRewrite(linalg::GenericOp op,
1328  PatternRewriter &rewriter) const override {
1329  // Only accept single output operations with pure tensor semantics.
1330  if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1331  return failure();
1332 
1333  // Only accept trivial affine indices.
// NOTE(review): the condition line of this `if` (1334) is missing from this
// copy.
1335  return failure();
1336 
1337  // Only accept scheduled loops.
1338  if (!op->hasAttr("sorted")) {
1339  return rewriter.notifyMatchFailure(
1340  op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1341  "before sparsification.");
1342  }
1343 
1344  // Must have been demapped as well if the generic op is sorted.
// NOTE(review): the check that belongs to the comment above (line 1345) is
// missing from this copy.
1346 
1347  // Sets up a code generation environment.
1348  const unsigned numTensors = op->getNumOperands();
1349  const unsigned numLoops = op.getNumLoops();
1350  bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1351  // If we have indexing map like (d0) -> (0, d0), there might be more
1352  // levels than loops because of the constant index, that means we can not
1353  // use numLoops as the upper bound for ranks of all tensors.
1354  // TODO: Constant indices are currently not supported on sparse tensor, but
1355  // are allowed in non-annotated dense tensor. Support it, it would be
1356  // required for sparse tensor slice rank reducing too.
1357  Level maxLvlRank = 0;
1358  for (auto operand : op.getOperands()) {
1359  if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1360  maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1361  }
1362  }
1363 
1364  // Detects sparse annotations and translates the per-level sparsity
1365  // information for all tensors to loop indices in the kernel.
1366  CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1367  if (!findSparseAnnotations(env, needIdxRed))
1368  return failure();
1369 
1370  // Only standard reduction operations (add, sub, or, xor) that can be
1371  // sparsified by merely reducing the stored values are admissible. More
1372  // elaborate reduction operations (such as mul, and, min, max) would need
1373  // to know whether implicit zeros occur as well. They can still be
1374  // implemented with a custom reduction operation, accepted here as well.
1375  if (op.getNumReductionLoops() > 0) {
1376  Operation *yield = op.getRegion().front().getTerminator();
1377  assert(isa<linalg::YieldOp>(yield));
// NOTE(review): getDefiningOp() returns null when the yielded value is a
// block argument. The isa<> checks below do not accept a null pointer (they
// assert in debug builds). Consider rejecting that case with failure()
// first.
1378  Operation *redop = yield->getOperand(0).getDefiningOp();
1379  if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1380  !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1381  !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1382  !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1383  !isa<ReduceOp>(redop)) {
1384  return failure();
1385  }
1386  }
1387 
1388  // Constructs the tensor expressions tree from `op`, returns failure if the
1389  // tree can not be built or the tensor expression is inadmissible.
1390  if (failed(env.initTensorExp()))
1391  return failure();
1392 
1393  // Recursively generates code if admissible.
1394  env.startEmit(options.sparseEmitStrategy);
1395  genBuffers(env, rewriter);
1396  // TODO: Constant affine expression should be handled differently when using
1397  // slice-based codegen, it does not matter now because we already reject the
1398  // constant expression at an earlier stage.
1399  genInitConstantDenseAddress(env, rewriter);
1400  genStmt(env, rewriter, env.getExprId(), 0);
1401  genResult(env, rewriter);
1402  return success();
1403  }
1404 
1405 private:
1406  /// Options to control sparse code generation.
// NOTE(review): the member declaration for `options` (line 1407) is missing
// from this copy.
1408 };
1409 
1410 } // namespace
1411 
1412 /// Populates the given patterns list with rewriting rules required for
1413 /// the sparsification of linear algebra operations.
/// Registers the single `GenericOpSparsifier` pattern, configured with
/// `options`.
/// NOTE(review): the first signature line (1414) is missing from this copy.
1415  RewritePatternSet &patterns, const SparsificationOptions &options) {
1416  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1417 }
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::pair< Operation *, bool > startLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, LatPointId li, bool needsUniv)
Starts a single loop in current sequence.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, Value tensor)
Gets the total number of compound affine expressions in the getMatchingIndexingMap for the given tens...
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-iteration over the given conjunction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, OpOperand *t)
Generates insertion code to implement dynamic tensor load for reduction.
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop)
Returns true iff affine expression is invariant.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, AffineExpr a, LevelType lt, bool isSubExp=false, int64_t coefficient=1)
Helper method to inspect affine expressions for index variable reduction based codegen.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, LatPointId p)
Generates a single if-statement within a while-loop.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl< Value > &args)
Generates subscript for load/store on a dense or sparse tensor.
static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, Value sparseOut, ValueRange ivs, Value v)
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, bool isStart)
Generates an expanded access pattern in innermost dimension.
static void genConstantDenseAddressFromLevel(CodegenEnv &env, OpBuilder &builder, TensorId tid, Level startLvl)
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId curr, LatSetId lts)
Starts a loop sequence at given level.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId curr, bool isStart)
Hoists loop invariant tensor loads for which indices have been exhausted.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned at)
Ends a loop sequence at given level.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse)
Returns whether a loop may be parallelized under the requested parallelization strategy.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, LevelType lt, bool setLvlFormat=true)
Helper method to inspect affine expressions.
static Operation * genCoIteration(CodegenEnv &env, OpBuilder &builder, ArrayRef< TensorLevel > tidLvls, bool tryParallel, bool needsUniv)
Emit a loop to coiterate over the list of tensor levels.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased)
Helper method to inspect sparse encodings in the tensor types.
static bool getAllTidLvlsInLatPoints(CodegenEnv &env, LatPointId li, LoopId curr, llvm::function_ref< void(TensorLevel, AffineExpr)> callback)
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, OpOperand *t)
Generates insertion code to implement dynamic tensor load.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op)
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, Value redInput, Value cntInput, Value insInput, Value validIns)
Generates end of true branch of if-statement within a while-loop.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, LoopId curr)
Recursively generates code while computing iteration lattices in order to manage the complexity of implementing co-iteration over unions and intersections of sparse iteration spaces.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, Value rhs)
Generates insertion code to implement dynamic tensor store.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, Value rhs)
Generates a store on a dense or sparse tensor.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, Value e)
Semi-ring branches are simply inlined by the sparsifier.
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
static void genResult(CodegenEnv &env, RewriterBase &rewriter)
Converts the result computed by the sparse kernel into the required form.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr, ArrayRef< TensorLevel > tidLvls)
Whether or not the current loop being generated should be parallelized (if possible) according to the configuration.
static bool translateBitsToTidLvlPairs(CodegenEnv &env, LatPointId li, LoopId curr, SmallVectorImpl< TensorLevel > &tidLvls, SmallVectorImpl< std::pair< TensorLevel, AffineExpr >> &affineTidLvls)
Returns true if the lattice bit can be iterated by a for loop.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e)
Recursively generates tensor expression.
static void genInitConstantDenseAddress(CodegenEnv &env, RewriterBase &rewriter)
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp)
Generates a load on a dense or sparse tensor.
static Value genInvariantValue(CodegenEnv &env, ExprId exp)
Generates an invariant value.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, LatPointId li, bool needsUniv, bool isSingleCond)
Ends a single loop in the current sequence. Returns the updated value of needsUniv.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, bool needsUniv)
Generates the induction structure for a while-loop.
static Value genIndex(CodegenEnv &env, OpOperand *t)
Generates index for load/store on sparse tensor.
Base type for affine expression.
Definition: AffineExpr.h:69
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:383
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
The code generation environment class aggregates a number of data structures that are needed during t...
Definition: CodegenEnv.h:35
void startReduc(ExprId exp, Value val)
Definition: CodegenEnv.cpp:228
void updateValidLexInsert(Value val)
Definition: CodegenEnv.cpp:256
const SparsificationOptions & options() const
Definition: CodegenEnv.h:51
std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)
Generates loop boundary statements (entering/exiting loops).
Definition: CodegenEnv.cpp:103
ArrayRef< LatPointId > set(LatSetId s) const
Definition: CodegenEnv.h:79
bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const
Definition: CodegenEnv.cpp:200
unsigned getCurrentDepth() const
Definition: CodegenEnv.h:111
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tl) const
Definition: CodegenEnv.h:103
TensorLevel makeTensorLevel(TensorId t, Level l) const
Definition: CodegenEnv.h:91
const LatPoint & lat(LatPointId l) const
Definition: CodegenEnv.h:78
constexpr TensorId makeTensorId(unsigned t) const
Definition: CodegenEnv.h:68
void startExpand(Value values, Value filled, Value added, Value count)
Definition: CodegenEnv.cpp:205
unsigned getLoopNum() const
Definition: CodegenEnv.h:85
void updateInsertionChain(Value chain)
Definition: CodegenEnv.cpp:195
void startCustomReduc(ExprId exp)
Definition: CodegenEnv.cpp:266
linalg::GenericOp op() const
Definition: CodegenEnv.h:50
Value getLoopVar(LoopId i) const
Returns the induction-variable for the given loop.
Definition: CodegenEnv.cpp:187
void startEmit(SparseEmitStrategy emitStrategy)
Definition: CodegenEnv.cpp:62
auto unpackTensorLevelRange(ContainerTy &&c) const
Definition: CodegenEnv.h:107
const TensorExp & exp(ExprId e) const
Definition: CodegenEnv.h:77
void updateExpandCount(Value count)
Definition: CodegenEnv.cpp:214
bool isSparseOutput(OpOperand *o) const
Definition: CodegenEnv.h:129
void startValidLexInsert(Value val)
Definition: CodegenEnv.cpp:251
constexpr LoopId makeLoopId(unsigned i) const
Definition: CodegenEnv.h:71
LevelType lt(TensorId t, LoopId i) const
Definition: CodegenEnv.h:80
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
Definition: LoopEmitter.h:234
void locateLvlAtAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr)
Emits the address for a dense level based on the value evaluated by the provided affine expression.
const std::vector< Value > & getValBuffer() const
Definition: LoopEmitter.h:232
void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls)
Enters a new loop sequence; the loops within the same sequence start from the break points of previo...
Value genAffine(OpBuilder &builder, Location loc, AffineExpr a)
Generates code to compute an affine expression whose variables are LoopIds (i.e., a....
Value getLoopIV(LoopId n) const
Gets loop induction variable for the given loop.
Definition: LoopEmitter.h:171
SmallVector< Value > getValPosits(TensorId tid) const
Getters.
Definition: LoopEmitter.h:223
auto getLoopIVsRange() const
Get the range of values for all induction variables.
Definition: LoopEmitter.h:157
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
void exitCurrentLoopSeq(OpBuilder &builder, Location loc)
Exits the current loop sequence, this will reset universal index to 0.
Value getCoord(TensorId tid, Level lvl) const
Definition: LoopEmitter.h:229
A class to handle all iteration lattice operations.
Definition: Merger.h:225
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on the ith loop.
Definition: Merger.h:416
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:940
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition: Merger.h:249
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
Definition: Merger.h:426
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
Definition: Merger.h:442
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:431
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
Definition: Merger.h:472
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to a sparse level-type.
Definition: Merger.cpp:671
void clearExprValue(ExprId e)
Clears the value associated with the expression.
Definition: Merger.h:567
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
Definition: Merger.h:367
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
Definition: Merger.cpp:503
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
Definition: Merger.h:348
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
Definition: Merger.h:363
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Definition: Merger.h:399
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:1618
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
Definition: Merger.h:559
bool hasDependentLvl(LoopId i, TensorId t)
Returns whether the loop has a dependent slice.
Definition: Merger.h:481
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllDense() const
Returns true for tensors where every level is dense.
Level getLvlRank() const
Returns the level-rank.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition: Merger.h:30
bool isUniqueLT(LevelType lt)
Definition: Enums.h:424
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:312
unsigned LatSetId
LatSet identifiers.
Definition: Merger.h:57
unsigned TensorLevel
Definition: LoopEmitter.h:26
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:35
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition: Merger.h:44
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
unsigned LoopId
Loop identifiers.
Definition: Merger.h:38
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:359
Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr)
Definition: CodegenUtils.h:403
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnySparseType(TypeRange types)
Returns true iff the type range has any sparse tensor type.
Definition: SparseTensor.h:106
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
bool isUndefLT(LevelType lt)
Definition: Enums.h:408
bool isDenseLT(LevelType lt)
Definition: Enums.h:409
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
unsigned ExprId
TensorExp identifiers.
Definition: Merger.h:48
unsigned LatPointId
LatPoint identifiers.
Definition: Merger.h:52
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition: Merger.h:35
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Options for the Sparsification pass.
Definition: Passes.h:97
SparseParallelizationStrategy parallelizationStrategy
Definition: Passes.h:110
ExprId exp
Identifier of the tensor expression.
Definition: Merger.h:218
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
Definition: Merger.h:215
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
constexpr bool hasSparseSemantic() const
Check if the LevelType is considered to be sparse.
Definition: Enums.h:337
constexpr bool hasDenseSemantic() const
Check if the LevelType is considered to be dense-like.
Definition: Enums.h:343
Tensor expression. Represents an MLIR expression in tensor index notation.
Definition: Merger.h:67
LoopId loop
kLoopVar expressions simply have a loop identifier.
Definition: Merger.h:96
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition: Merger.h:105
Children children
All other expressions hold the ExprIds of their children.
Definition: Merger.h:99
TensorId tensor
kTensor expressions simply have a tensor identifier.
Definition: Merger.h:93
Kind kind
Tensor expression kind.
Definition: Merger.h:89
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.