1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
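// As a rough illustration (a hand-written sketch, not the exact IR emitted by
// this pass): sparsifying a kernel such as x(i) = a(i) * b(i), with `a`
// annotated as compressed and `b` and `x` dense, conceptually yields a loop
// over the stored nonzeros of `a` only:
//
//   %pos = sparse_tensor.positions %a ...
//   %crd = sparse_tensor.coordinates %a ...
//   %val = sparse_tensor.values %a ...
//   scf.for %p = %pos[0] to %pos[1] step %c1 {
//     %i  = memref.load %crd[%p]    // coordinate of the p-th nonzero
//     %va = memref.load %val[%p]    // stored value a(i)
//     %vb = memref.load %b[%i]      // random access into dense b
//     memref.store %va * %vb, %x[%i]
//   }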
12 
13 #include "CodegenEnv.h"
14 #include "CodegenUtils.h"
15 #include "LoopEmitter.h"
16 
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TensorEncoding.h"
36 #include "llvm/ADT/SmallBitVector.h"
37 #include <optional>
38 
39 using namespace mlir;
40 using namespace mlir::sparse_tensor;
41 
42 //===----------------------------------------------------------------------===//
43 // Sparsifier analysis methods.
44 //===----------------------------------------------------------------------===//
45 
46 // TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
47 // and those letters are too easy to confuse visually. We should switch
48 // to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
49 // (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
50 
51 /// Determines if affine expression is invariant.
52 static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
53  bool &isAtLoop) {
54  switch (a.getKind()) {
55  case AffineExprKind::DimId: {
56  const LoopId i = cast<AffineDimExpr>(a).getPosition();
57  if (i == ldx) {
58  isAtLoop = true;
59  // Must be invariant if we are at the given loop.
60  return true;
61  }
62  // The DimExpr is invariant if the loop has already been generated.
63  return i < loopDepth;
64  }
65  case AffineExprKind::Add:
66  case AffineExprKind::Mul: {
67  auto binOp = cast<AffineBinaryOpExpr>(a);
68  return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
69  isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
70  }
71  default: {
72  assert(isa<AffineConstantExpr>(a));
73  return true;
74  }
75  }
76 }
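// A usage sketch (hypothetical values, not taken from a real kernel): with
// loops d0 and d1 already generated (loopDepth = 2) and ldx = d2,
//
//   isInvariantAffine(d0 + d1, /*loopDepth=*/2, /*ldx=*/d2, isAtLoop) // true, isAtLoop unchanged
//   isInvariantAffine(d2,      /*loopDepth=*/2, /*ldx=*/d2, isAtLoop) // true, isAtLoop set
//   isInvariantAffine(d3,      /*loopDepth=*/2, /*ldx=*/d2, isAtLoop) // false, loop d3 not yet generated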
77 
78 /// Helper method to inspect affine expressions. Rejects cases where the
79 /// same index is used more than once. Also rejects compound affine
80 /// expressions in sparse dimensions.
81 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
82  LevelType lt, bool setLvlFormat = true) {
83  switch (a.getKind()) {
84  case AffineExprKind::DimId: {
85  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
86  if (!isUndefLT(merger.getLvlType(tid, idx)))
87  return false; // used more than once
88 
89  if (setLvlFormat)
90  merger.setLevelAndType(tid, idx, lvl, lt);
91  return true;
92  }
93  case AffineExprKind::Add:
94  case AffineExprKind::Mul:
95  case AffineExprKind::Constant: {
96  assert(isDenseLT(lt));
97  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
98  // We do not set the dim level format for an affine expression like
99  // d0 + d1 on either loop index d0 or d1.
100  // We continue the recursion merely to check whether the current affine
101  // expression is admissible or not.
102  return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
103  findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
104  }
105  // Falls through when it is a constant affine expression.
106  return true;
107  }
108  default:
109  return false;
110  }
111 }
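// For example (sketch): with map (d0, d1) -> (d0 + d1, d1) on a tensor whose
// levels are (dense, compressed), the first result d0 + d1 only recurses to
// verify admissibility (no level format is recorded for d0 or d1, since the
// level is dense), while the second result d1 records level 1 as compressed
// for loop d1. Non-trivial expressions on sparse levels are instead routed
// through findDepIdxSet (see findSparseAnnotations below).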
112 
113 /// Helper method to inspect affine expressions for index variable reduction
114 /// based codegen. It finds the dependent index set for all tensor levels in the
115 /// current expression we are generating.
116 ///
117 /// For example, when handling A[i+j][j+k], we build the two way mapping in
118 /// merger between (tensor, level) pairs and their dependent index variable set:
119 /// A_0 <=> [i, j] and A_1 <=> [j, k]
120 ///
121 /// It rejects cases (returns false)
122 /// 1st, when the same index is used more than once, e.g., A[i+j][i]
123 /// 2nd, when multiplication is used in the non-trivial index expression.
124 /// 3rd, when a constant operand is used in the non-trivial index expression.
125 ///
126 /// TODO: constant should be easy to handle.
127 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
128  AffineExpr a, LevelType lt, bool isSubExp = false,
129  int64_t coefficient = 1) {
130  switch (a.getKind()) {
131  case AffineExprKind::DimId: {
132  // Only allow positive coefficients on AffineDimExpr.
133  if (coefficient <= 0)
134  return false;
135 
136  const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
137  if (!isUndefLT(merger.getLvlType(tensor, ldx)))
138  return false; // used more than once, e.g., A[i][i]
139 
140  // TODO: Generalize the following two cases. A[i] (with trivial index
141  // expression) can be treated as a special affine index expression. We do
142  // not necessarily need to differentiate them.
143  if (!isSubExp) {
144  assert(coefficient == 1);
145  merger.setLevelAndType(tensor, ldx, lvl, lt);
146  }
147 
148  if (isSubExp) {
149  // The current loop appears in more than one affine expression on the
150  // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
151  // is used twice.
152  if (merger.hasDependentLvl(ldx, tensor)) {
153  // TODO: This can be supported by coiterating slices if the loop idx
154  // appears in affine indices of different tensors, or by taking slices
155  // along multiple dimensions when it is on the same tensor.
156  // E.g.,
157  // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
158  // d0_1 = getNextSliceOffset t0 along lvl0
159  // d0_2 = getNextSliceOffset t1 along lvl0
160  // if d0_1 == d0_2 then d0 = d0_1 = d0_2
161  // else increase min(d0_1, d0_2).
162  return false;
163  }
164  merger.setLoopDependentTensorLevel(ldx, tensor, lvl, lt, coefficient);
165  }
166  return true;
167  }
168  case AffineExprKind::Constant:
169  case AffineExprKind::Mul: {
170  // TODO: Support index expressions like `2 * d0`; we currently only
171  // support more complicated cases like `2 * d0 + d1`.
172  if (!isSubExp)
173  return false;
174 
175  // TODO: Support Constant AffineExp for slice-based codegen
176  if (isa<AffineConstantExpr>(a))
177  llvm_unreachable("Not yet implemented");
178 
179  auto binOp = cast<AffineBinaryOpExpr>(a);
180  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
181  if (isa<AffineConstantExpr>(rhs))
182  std::swap(lhs, rhs);
183  // Must be in form of `constant * d`.
184  assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
185  int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
186  return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
187  }
188  case AffineExprKind::Add: {
189  auto binOp = cast<AffineBinaryOpExpr>(a);
190  return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
191  findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
192  }
193  default:
194  return false;
195  }
196 }
197 
198 /// Get the total number of compound affine expressions in the
199 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
200 ///
201 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
202 ///
203 /// Returns 1 (because the first level is compressed and its corresponding
204 /// indexing-expression is `d0 + d1`)
205 static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
206  Value tensor) {
207  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
208  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
209  // However, we don't need to handle `StorageSpecifierType`, so we
210  // can use `SparseTensorType` once we guard against non-tensors.
211  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
212  if (!rtp)
213  return 0;
214  const SparseTensorType stt(rtp);
215 
216  const Level lvlRank = stt.getLvlRank();
217  const auto exprs = map.getResults();
218  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
219  "AffineMap does not have dimension-rank many results");
220  unsigned num = 0;
221  for (Level l = 0; l < lvlRank; l++) {
222  if (!isa<AffineDimExpr>(exprs[l]) && !stt.isDenseLvl(l))
223  num++;
224  }
225  return num;
226 }
227 
228 /// Get the total number of sparse levels with compound affine
229 /// expressions, summed over all operands of the `GenericOp`.
230 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
231  unsigned num = 0;
232  for (OpOperand &t : op->getOpOperands())
233  num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
234  t.get());
235  return num;
236 }
237 
238 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
239  OpOperand *out = op.getDpsInitOperand(0);
240  if (getSparseTensorType(out->get()).isAllDense())
241  return false;
242  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
243  out->get());
244 }
245 
246 /// Helper method to inspect sparse encodings in the tensor types.
247 /// Fills the per-dimension sparsity information for all tensors.
248 /// Returns true if the sparse annotations and affine subscript
249 /// expressions of all tensors are admissible. Returns false if
250 /// no annotations are found or inadmissible constructs occur.
251 /// We currently support two different ways to handle non-trivial index
252 /// expressions on sparse tensors, and they accept different affine expressions.
253 /// When using the dependent-index-reduction-based approach, it currently only
254 /// supports affine addition index expressions.
255 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
256  bool annotated = false;
257  for (OpOperand &t : env.op()->getOpOperands()) {
258  const TensorId tid = env.makeTensorId(t.getOperandNumber());
259  const auto map = env.op().getMatchingIndexingMap(&t);
260  const auto enc = getSparseTensorEncoding(t.get().getType());
261  if (enc)
262  annotated = true;
263 
264  const Level lvlRank = map.getNumResults();
265  assert(!enc || lvlRank == enc.getLvlRank());
266  assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
267 
268  // We only need to do index reduction if there is at least one non-trivial
269  // index expression on sparse levels.
270  // If all non-trivial index expressions are on dense levels, we can
271  // efficiently rely on random access to locate the element.
272  bool needIdxReduc =
273  enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
274  // If the current tensor being inspected requires an affine index, it
275  // needs to be sliced.
276  for (Level l = 0; l < lvlRank; l++) {
277  const AffineExpr a = map.getResult(l);
278  const LevelType lt = enc.getLvlType(l);
279  if (idxReducBased && needIdxReduc) {
280  if (!findDepIdxSet(env.merger(), tid, l, a, lt))
281  return false; // inadmissible affine expression
282  } else {
283  if (!findAffine(env.merger(), tid, l, a, lt))
284  return false; // inadmissible affine expression
285  }
286  }
287  }
288  return annotated;
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // Sparsifier synthesis methods (statements and expressions).
293 //===----------------------------------------------------------------------===//
294 
295 /// Local bufferization of all dense and sparse data structures.
296 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
297  linalg::GenericOp op = env.op();
298  Location loc = op.getLoc();
299  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
300 
301  SmallVector<Range, 4> loopRange =
302  llvm::cast<linalg::LinalgOp>(op.getOperation())
303  .createLoopRanges(builder, loc);
304 
305  env.emitter().initializeLoopEmit(
306  builder, loc,
307  /// Generates buffer for the output tensor.
308  /// Note that all sparse kernels assume that when all elements are written
309  /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
310  /// to all zeroes and only nonzero values are computed and written out.
311  /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
312  /// for the updates and no assumption on the original contents of the
313  /// output buffer is necessary.
314  [&op](OpBuilder &builder, Location loc, Value memref,
315  Value tensor) -> Value {
316  // Must not be a sparse tensor.
317  assert(!getSparseTensorEncoding(tensor.getType()));
318  // Two output tensor references should point to the same object.
319  OpOperand *lhs = op.getDpsInitOperand(0);
320  assert(lhs->get() == tensor);
321  // An output tensor can simply materialize from the buffer of the tensor
322  // that appears in the outs() clause. For updates, this has the
323  // advantage that only the nonzero values are involved in the
324  // computation, keeping the operation O(nnz). In all other cases, we are
325  // forced to zero out the buffer to enforce the assumption above, which
326  // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
327  // O(nnz) for matrices).
328  // TODO: use better analysis to avoid zeroing out the buffer?
329  bool isInit = op.isInitTensor(lhs);
330  Value init = memref;
331  if (!isInit) {
332  Value zero = constantZero(builder, loc,
333  getElementTypeOrSelf(tensor.getType()));
334  builder.create<linalg::FillOp>(loc, ValueRange{zero},
335  ValueRange{init});
336  }
337  return init;
338  },
339  [&loopRange](OpBuilder &b, Location loc, Level l) {
340  assert(l < loopRange.size());
341  return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
342  });
343 }
344 
345 /// Generates index for load/store on sparse tensor.
346 // FIXME: It's not entirely clear what "index" means here (i.e., is it
347 // a "coordinate", or "Ldx", or what). So the function should be renamed
348 // and/or the documentation expanded in order to clarify.
349 static Value genIndex(CodegenEnv &env, OpOperand *t) {
350  const auto map = env.op().getMatchingIndexingMap(t);
351  const auto stt = getSparseTensorType(t->get());
352  const Level lvlRank = stt.getLvlRank();
353  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
354  const AffineExpr a = map.getResult(lvlRank - 1);
355  assert(a.getKind() == AffineExprKind::DimId);
356  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
357  return env.getLoopVar(idx);
358 }
359 
360 /// Generates subscript for load/store on a dense or sparse tensor.
361 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
362  SmallVectorImpl<Value> &args) {
363  const Location loc = env.op().getLoc();
364  const TensorId tid = env.makeTensorId(t->getOperandNumber());
365  const auto map = env.op().getMatchingIndexingMap(t);
366  const auto stt = getSparseTensorType(t->get());
367  if (stt.hasEncoding()) {
368  // For sparse tensors we only push the last-level's position onto `args`.
369  const auto pos = env.emitter().getPosits()[tid].back();
370  assert(pos);
371  args.push_back(pos);
372  } else {
373  // For dense tensors we push all level's coordinates onto `args`.
374  const Level lvlRank = stt.getLvlRank();
375  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
376  for (Level l = 0; l < lvlRank; l++) {
377  const auto lvlExpr = map.getResult(l);
378  const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
379  args.push_back(lvlCrd);
380  }
381  }
382  return env.emitter().getValBuffer()[tid];
383 }
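// A sketch of the two cases (hypothetical SSA names): for a dense 2-D operand
// with map (d0, d1), `args` becomes {%i, %j} and the access is
// memref.load %denseBuf[%i, %j]; for a sparse (e.g. CSR) operand, `args`
// becomes {%pos} and the access is memref.load %values[%pos], where %pos is
// the position within the innermost stored level.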
384 
385 /// Generates insertion code to implement dynamic tensor load.
386 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
387  OpOperand *t) {
388  linalg::GenericOp op = env.op();
389  Location loc = op.getLoc();
390  // Direct lexicographic coordinate order, tensor loads as zero.
391  if (!env.isExpand()) {
392  Type tp = getElementTypeOrSelf(t->get().getType());
393  return constantZero(builder, loc, tp);
394  }
395  // Load from expanded access pattern.
396  Value index = genIndex(env, t);
397  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
398 }
399 
400 /// Generates insertion code to implement dynamic tensor load for reduction.
401 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
402  OpOperand *t) {
403  linalg::GenericOp op = env.op();
404  Location loc = op.getLoc();
405  Value identity = env.getCustomRedId();
406  // Direct lexicographic coordinate order, tensor loads as identity.
407  if (!env.isExpand())
408  return identity;
409  // Load from expanded access pattern if filled, identity otherwise.
410  Value values = env.getExpandValues();
411  Value filled = env.getExpandFilled();
412  Value index = genIndex(env, t);
413  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
414  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
415  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
416 }
417 
418 /// Generates insertion code to implement dynamic tensor store.
419 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
420  Value rhs) {
421  linalg::GenericOp op = env.op();
422  Location loc = op.getLoc();
423  // Direct insertion in lexicographic coordinate order.
424  if (!env.isExpand()) {
425  const LoopOrd numLoops = op.getRank(t);
426  // Retrieves the first `numLoops` induction variables.
427  SmallVector<Value> ivs = llvm::to_vector(
428  llvm::drop_end(env.emitter().getLoopIVsRange(),
429  env.emitter().getCurrentDepth() - numLoops));
430  Value chain = env.getInsertionChain();
431  if (!env.getValidLexInsert()) {
432  env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
433  } else {
434  // Generates runtime check for a valid lex during reduction,
435  // to avoid inserting the identity value for empty reductions.
436  // if (validLexInsert) then
437  // insert(rhs) into chain
438  // return updated chain
439  // else
440  // return unmodified chain
441  scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
442  loc, chain.getType(), env.getValidLexInsert(),
443  /*else=*/true);
444  // True branch.
445  builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
446  Value res = builder.create<InsertOp>(loc, rhs, chain, ivs);
447  builder.create<scf::YieldOp>(loc, res);
448  // False branch.
449  builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
450  builder.create<scf::YieldOp>(loc, chain);
451  // Value assignment.
452  builder.setInsertionPointAfter(ifValidLexInsert);
453  env.updateInsertionChain(ifValidLexInsert.getResult(0));
454  }
455  return;
456  }
457  // Generates insertion code along expanded access pattern.
458  // if (!expFilled[i]) then
459  // expFilled[i] = true
460  // expAdded[inserts++] = i
461  // endif
462  // values[i] = rhs
463  Value values = env.getExpandValues();
464  Value filled = env.getExpandFilled();
465  Value added = env.getExpandAdded();
466  Value count = env.getExpandCount();
467  Value index = genIndex(env, t);
468  Value fval = constantI1(builder, loc, false);
469  Value tval = constantI1(builder, loc, true);
470  // If statement.
471  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
472  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
473  isFilled, fval);
474  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
475  /*else=*/true);
476  // True branch.
477  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
478  builder.create<memref::StoreOp>(loc, tval, filled, index);
479  builder.create<memref::StoreOp>(loc, index, added, count);
480  Value one = constantIndex(builder, loc, 1);
481  Value add = builder.create<arith::AddIOp>(loc, count, one);
482  builder.create<scf::YieldOp>(loc, add);
483  // False branch.
484  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
485  builder.create<scf::YieldOp>(loc, count);
486  builder.setInsertionPointAfter(ifOp);
487  // Value assignment.
488  env.updateExpandCount(ifOp.getResult(0));
489  builder.create<memref::StoreOp>(loc, rhs, values, index);
490 }
491 
492 /// Generates a load on a dense or sparse tensor.
493 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
494  // Test if the load was hoisted to a higher loop nest.
495  Value val = env.exp(exp).val;
496  if (val)
497  return val;
498 
499  // Load during insertion.
500  linalg::GenericOp op = env.op();
501  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
502  if (env.isSparseOutput(t)) {
503  if (env.isCustomReduc())
504  return genInsertionLoadReduce(env, builder, t);
505  return genInsertionLoad(env, builder, t);
506  }
507  // Actual load.
508  SmallVector<Value> args;
509  Value ptr = genSubscript(env, builder, t, args);
510  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
511 }
512 
513 /// Generates a store on a dense or sparse tensor.
514 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
515  Value rhs) {
516  // Only unary and binary ops are allowed to return an uninitialized rhs
517  // to indicate a missing output, or otherwise a custom reduction that
518  // received no value to accumulate.
519  if (!rhs) {
520  assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
521  env.exp(exp).kind == TensorExp::Kind::kBinary ||
522  env.exp(exp).kind == TensorExp::Kind::kReduce);
523  return;
524  }
525  // Test if this is a scalarized reduction.
526  if (env.isReduc()) {
527  env.updateReduc(rhs);
528  return;
529  }
530  // Regular store.
531  linalg::GenericOp op = env.op();
532  Location loc = op.getLoc();
533  OpOperand *t = op.getDpsInitOperand(0);
534  if (!env.isSparseOutput(t)) {
535  SmallVector<Value> args;
536  Value ptr = genSubscript(env, builder, t, args);
537  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
538  return;
539  }
540  // Store during sparse insertion.
541  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
542  genInsertionStore(env, builder, t, rhs);
543  return;
544  }
545  // Select operation insertion.
546  Value chain = env.getInsertionChain();
547  scf::IfOp ifOp =
548  builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
549  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
550  // Existing value was preserved to be used here.
551  assert(env.exp(exp).val);
552  Value v0 = env.exp(exp).val;
553  genInsertionStore(env, builder, t, v0);
554  env.merger().clearExprValue(exp);
555  // Yield modified insertion chain along true branch.
556  Value mchain = env.getInsertionChain();
557  builder.create<scf::YieldOp>(op.getLoc(), mchain);
558  // Yield original insertion chain along false branch.
559  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
560  builder.create<scf::YieldOp>(loc, chain);
561  // Done with if statement.
562  env.updateInsertionChain(ifOp->getResult(0));
563  builder.setInsertionPointAfter(ifOp);
564 }
565 
566 /// Generates an invariant value.
567 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
568  return env.exp(exp).val;
569 }
570 
571 /// Semi-ring branches are simply inlined by the sparsifier. Prior
572 /// analysis has verified that all computations are "local" to the inlined
573 /// branch or otherwise invariantly defined outside the loop nest, with the
574 /// exception of index computations, which need to be relinked to actual
575 /// inlined cloned code.
576 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
577  Value e, LoopId ldx) {
578  if (auto arg = dyn_cast<BlockArgument>(e)) {
579  // Direct arguments of the original linalg op must be converted
580  // into dense tensor loads. Note that we should not encounter
581  // anything else. This needs to be verified by semi-ring ops.
582  linalg::GenericOp op = env.op();
583  if (arg.getOwner()->getParentOp() == op) {
584  const TensorId tid = env.makeTensorId(arg.getArgNumber());
585  OpOperand *t = &op->getOpOperand(tid);
586  assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
587  SmallVector<Value> args;
588  Value ptr = genSubscript(env, rewriter, t, args);
589  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
590  }
591  } else if (Operation *def = e.getDefiningOp()) {
592  // Handle index computation.
593  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
594  return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
595  // When still defined in new body, recurse into operands.
596  if (def->getBlock() == block) {
597  rewriter.setInsertionPoint(def);
598  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
599  rewriter.updateRootInPlace(def, [&]() {
600  def->setOperand(
601  i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
602  });
603  }
604  }
605  }
606  return e;
607 }
608 
609 /// Recursively generates tensor expression.
610 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
611  LoopId ldx) {
612  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
613  return Value();
614 
615  linalg::GenericOp op = env.op();
616  Location loc = op.getLoc();
617  const TensorExp &exp = env.exp(e);
618  const auto kind = exp.kind;
619  if (kind == TensorExp::Kind::kTensor)
620  return genTensorLoad(env, rewriter, e);
621  if (kind == TensorExp::Kind::kInvariant)
622  return genInvariantValue(env, e);
623  if (kind == TensorExp::Kind::kLoopVar)
624  return env.getLoopVar(exp.loop);
625 
626  if (kind == TensorExp::Kind::kReduce)
627  env.startCustomReduc(e); // enter custom
628 
629  Value v0, v1;
630  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
631  // based on the type of the other operand.
632  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
633  env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
634  v1 = genExp(env, rewriter, exp.children.e1, ldx);
635  v0 = constantZero(rewriter, loc, v1.getType());
636  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
637  env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
638  v0 = genExp(env, rewriter, exp.children.e0, ldx);
639  v1 = constantZero(rewriter, loc, v0.getType());
640  } else {
641  v0 = genExp(env, rewriter, exp.children.e0, ldx);
642  v1 = genExp(env, rewriter, exp.children.e1, ldx);
643  }
644 
645  Value ee;
646  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
647  // custom reduce did not receive a value
648  } else {
649  ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
650  if (ee &&
651  (kind == TensorExp::Kind::kUnary ||
652  kind == TensorExp::Kind::kBinary ||
653  kind == TensorExp::Kind::kReduce ||
654  kind == TensorExp::Kind::kSelect)) {
655  OpBuilder::InsertionGuard guard(rewriter);
656  ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
657  }
658  }
659 
660  if (kind == TensorExp::Kind::kReduce)
661  env.endCustomReduc(); // exit custom
662 
663  if (kind == TensorExp::Kind::kSelect)
664  env.merger().setExprValue(e, v0); // Preserve value for later use.
665 
666  return ee;
667 }
668 
669 /// Hoists loop invariant tensor loads for which indices have been exhausted.
670 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
671  LoopId ldx, bool atStart) {
672  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
673  return;
674  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
675  // Inspect tensor indices.
676  bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
677  linalg::GenericOp op = env.op();
678  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
679  const auto map = op.getMatchingIndexingMap(&t);
680  const auto stt = getSparseTensorType(t.get());
681  const Level lvlRank = stt.getLvlRank();
682  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
683  for (Level l = 0; l < lvlRank; l++) {
684  const AffineExpr a = map.getResult(l);
685  if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
686  return; // still in play
687  }
688  // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
689  if (!isAtLoop)
690  return;
691  OpOperand *lhs = op.getDpsInitOperand(0);
692  if (lhs == &t) {
693  // Start or end a scalarized reduction.
694  if (atStart) {
695  Value load = env.isCustomReduc() ? env.getCustomRedId()
696  : genTensorLoad(env, builder, exp);
697  env.startReduc(exp, load);
698  if (env.hasSparseOutput())
699  env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
700  } else {
701  genTensorStore(env, builder, exp, env.endReduc());
702  env.clearValidLexInsert();
703  }
704  } else {
705  // Start or end loop invariant hoisting of a tensor load.
706  if (atStart)
707  env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
708  else
709  env.merger().clearExprValue(exp);
710  }
711  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
712  env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
713  env.exp(exp).kind != TensorExp::Kind::kSynZero) {
714  // Traverse into the binary operations. Note that we only hoist
715  // tensor loads, since subsequent MLIR/LLVM passes know how to
716  // deal with all other kinds of derived loop invariants.
717  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
718  env.startCustomReduc(exp); // enter custom
719  const ExprId e0 = env.exp(exp).children.e0;
720  const ExprId e1 = env.exp(exp).children.e1;
721  genInvariants(env, builder, e0, ldx, atStart);
722  genInvariants(env, builder, e1, ldx, atStart);
723  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
724  env.endCustomReduc(); // exit custom
725  }
726 }
727 
728 /// Generates an expanded access pattern in the innermost dimension.
729 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
730  bool atStart) {
731  linalg::GenericOp op = env.op();
732  OpOperand *lhs = op.getDpsInitOperand(0);
733  if (!env.atExpandLevel(lhs, op.getRank(lhs), at))
734  return; // not needed at this level
735  assert(!env.isReduc());
736  // Generate start or end of an expanded access pattern. Note that because
737  // an expansion does not rely on the ongoing contents of the sparse storage
738  // scheme, we can use the original tensor as incoming SSA value (which
739  // simplifies codegen a bit). If expansion on the actual contents is ever
740  // needed, we will need to use the SSA value in the insertion chain instead.
741  Value tensor = lhs->get();
742  Location loc = op.getLoc();
743  if (atStart) {
744  auto dynShape = {ShapedType::kDynamic};
745  Type etp = cast<ShapedType>(tensor.getType()).getElementType();
746  Type t1 = MemRefType::get(dynShape, etp);
747  Type t2 = MemRefType::get(dynShape, builder.getI1Type());
748  Type t3 = MemRefType::get(dynShape, builder.getIndexType());
749  Type t4 = builder.getIndexType();
750  auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
751  assert(r.getNumResults() == 4);
752  env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
753  r.getResult(3));
754  } else {
755  SmallVector<Value> indices;
756  for (LoopOrd i = 0; i < at; i++)
757  indices.push_back(env.emitter().getLoopIV(i));
758  Value values = env.getExpandValues();
759  Value filled = env.getExpandFilled();
760  Value added = env.getExpandAdded();
761  Value count = env.getExpandCount();
762  Value chain = env.getInsertionChain();
763  Value compress = builder.create<CompressOp>(loc, values, filled, added,
764  count, chain, indices);
765  env.updateInsertionChain(compress);
766  env.endExpand();
767  }
768 }
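// A sketch of the resulting structure (not the exact emitted IR): for a
// kernel with a sparse output and one outer loop over i, each outer iteration
// is bracketed by expand/compress:
//
//   %values, %filled, %added, %count = sparse_tensor.expand %output
//   scf.for %j = ... {   // inner loop(s) scatter into the expanded row
//     ... genInsertionStore updates %values, %filled, %added, %count ...
//   }
//   %chain = sparse_tensor.compress %values, %filled, %added, %count
//            into %chain[%i]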
769 
770 /// Returns parallelization strategy. Any implicit loop in the Linalg
771 /// operation that is marked "parallel" is a candidate. Whether it is actually
772 /// converted to a parallel operation depends on the requested strategy.
773 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
774  // Reject parallelization of sparse output.
775  if (env.hasSparseOutput())
776  return false;
777  // Parallel loops on tensor expansion can cause data races.
778  if (env.isExpand())
779  return false;
780  // Inspect strategy.
781  switch (env.options().parallelizationStrategy) {
782  case SparseParallelizationStrategy::kNone:
783  return false;
784  case SparseParallelizationStrategy::kDenseOuterLoop:
785  return isOuter && !isSparse;
786  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
787  return isOuter;
788  case SparseParallelizationStrategy::kDenseAnyLoop:
789  return !isSparse;
790  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
791  return true;
792  }
793  llvm_unreachable("unexpected parallelization strategy");
794 }
795 
796 /// Whether or not the current loop being generated should be parallelized
797 /// (if possible) according to the configuration.
798 static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
799  ArrayRef<TensorLevel> tidLvls) {
800  linalg::GenericOp op = env.op();
801  auto iteratorTypes = op.getIteratorTypesArray();
802  bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
803  // Queries the LT based on the tensor id and loop idx, as requested by
804  // `CodegenEnv::lt(TensorId, LoopIdx)`. The returned LT from CodegenEnv
805  // should be consistent with the LT indexed by <TensorId, Level>.
806  const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
807  return isCompressedLT(lt) || isSingletonLT(lt);
808  });
809 
810  return isParallelFor(env, isOuter, isSparse);
811 }
812 
813 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
814 /// can either be a for loop or while loop depending on whether there is at most
815 /// one sparse level in the list.
816 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
817  LoopId idx, ArrayRef<TensorLevel> tidLvls,
818  bool tryParallel, bool needsUniv) {
819  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
820  // Construct the while-loop with a parameter for each
821  // index.
822  return env.emitter().enterCoIterationOverTensorsAtLvls(
823  builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
824  /*genDedup=*/true, needsUniv);
825  });
826  assert(loop);
827  return loop;
828 }
829 
830 /// Generates a for-loop or a while-loop, depending on whether it implements
831 /// singleton iteration or co-iteration over the given conjunction.
832 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
833  bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
834  bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
835  return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
836 }
837 
838 /// Generates the induction structure for a while-loop.
839 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
840  bool needsUniv) {
841  Location loc = env.op().getLoc();
842  // Finalize each else branch of all if statements.
843  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
844  while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
845  builder.getInsertionBlock()->getParentOp())) {
846  // Break on IfOp for slicing filtering.
847  if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
848  StringAttr::get(ifOp->getContext(), "slice"))
849  break;
850 
851  unsigned y = 0;
852  SmallVector<Value> yields;
853  if (env.isReduc()) {
854  yields.push_back(env.getReduc());
855  env.updateReduc(ifOp.getResult(y++));
856  if (env.getValidLexInsert()) {
857  yields.push_back(env.getValidLexInsert());
858  env.setValidLexInsert(ifOp.getResult(y++));
859  }
860  }
861  if (env.isExpand()) {
862  yields.push_back(env.getExpandCount());
863  env.updateExpandCount(ifOp->getResult(y++));
864  }
865  if (env.getInsertionChain()) {
866  yields.push_back(env.getInsertionChain());
867  env.updateInsertionChain(ifOp->getResult(y++));
868  }
869  assert(y == yields.size());
870  builder.create<scf::YieldOp>(loc, yields);
871  builder.setInsertionPointAfter(ifOp);
872  }
873  }
874  // No need to set the insertion point here as LoopEmitter keeps track of the
875  // basic block where scf::Yield should be inserted.
876 }
877 
878 /// Generates a single if-statement within a while-loop.
879 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
880  LatPointId p) {
881  Location loc = env.op().getLoc();
882  SmallVector<Type> types;
883  Value cond;
884  env.merger().foreachTensorLoopId(
885  p, /*simple=*/true,
886  [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
887  bool isIdxRed) {
888  if (isIdxRed) {
889  // Since there is no 1:1 mapping from loop to level (multiple loops
890  // are required to resolve one level with non-trivial index
891  // expression), we need to reconstruct the tensor level types if this
892  // loop requires index reduction condition.
893  assert(lvl.has_value() && isUndefLT(lt));
894  auto stt = getSparseTensorType(env.op().getInputs()[tid]);
895  lt = stt.getLvlType(*lvl);
896  }
897  assert(ldx == env.merger().loop(b));
898  Value clause;
899  if (isCompressedLT(lt) || isSingletonLT(lt) ||
900  isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
901  assert(lvl.has_value());
902  const Value crd = env.emitter().getCoords()[tid][*lvl];
903  const Value lvar = env.getLoopVar(ldx);
904  clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
905  crd, lvar);
906  } else {
907  assert(isDenseLT(lt) || isUndefLT(lt));
908  clause = constantI1(builder, loc, true);
909  }
910  cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
911  });
912  if (env.isReduc()) {
913  types.push_back(env.getReduc().getType());
914  if (env.getValidLexInsert())
915  types.push_back(env.getValidLexInsert().getType());
916  }
917  if (env.isExpand())
918  types.push_back(builder.getIndexType());
919  if (env.getInsertionChain())
920  types.push_back(env.getInsertionChain().getType());
921  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
922  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
923  return ifOp;
924 }
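// For example (sketch): when coiterating two compressed operands a and b with
// universal loop variable %iv, the guard generated for the lattice point in
// which both participate looks like
//
//   %pa = arith.cmpi eq, %crd_a, %iv
//   %pb = arith.cmpi eq, %crd_b, %iv
//   %cond = arith.andi %pa, %pb
//   scf.if %cond -> (...) { ... } else { ... }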
925 
926 /// Generates end of true branch of if-statement within a while-loop.
927 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
928  Value redInput, Value cntInput, Value insInput,
929  Value validIns) {
930  SmallVector<Value> operands;
931  if (env.isReduc()) {
932  operands.push_back(env.getReduc());
933  env.updateReduc(redInput);
934  if (env.getValidLexInsert()) {
935  // Any overlapping indices during a reduction creates a valid lex insert.
936  operands.push_back(constantI1(builder, env.op().getLoc(), true));
937  env.setValidLexInsert(validIns);
938  }
939  }
940  if (env.isExpand()) {
941  operands.push_back(env.getExpandCount());
942  env.updateExpandCount(cntInput);
943  }
944  if (env.getInsertionChain()) {
945  operands.push_back(env.getInsertionChain());
946  env.updateInsertionChain(insInput);
947  }
948  if (!operands.empty())
949  builder.create<scf::YieldOp>(env.op().getLoc(), operands);
950  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
951 }
952 
953 //===----------------------------------------------------------------------===//
954 // Sparsifier synthesis methods (loop sequence).
955 //===----------------------------------------------------------------------===//
956 
957 /// Starts a loop sequence at given level. Returns true if
958 /// the universal loop index must be maintained at this level.
959 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
960  LoopOrd idx, LoopId ldx, LatSetId lts) {
961  assert(!env.getLoopVar(idx));
962  // Emit invariants at this loop sequence level.
963  genInvariants(env, builder, exp, ldx, /*atStart=*/true);
964  // Emit access pattern expansion for sparse tensor output.
965  genExpand(env, builder, idx, /*atStart=*/true);
966  // Emit further initialization at this loop sequence level.
967  const LatPointId l0 = env.set(lts)[0];
968  bool needsUniv = false;
969 
970  SmallVector<TensorLevel> tidLvls;
971  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
972  std::optional<Level> lvl,
973  LevelType lt, bool isIdxReduc) {
974  assert(env.merger().loop(b) == idx);
975  if (isDenseLT(lt) || isUndefLT(lt)) {
976  if (tid == env.merger().getSynTensorID()) {
977  // Needs loop emitter to set up loop bounds for synthetic tensor too if
978  // there is a loop condition imposed on the synthetic tensor.
979  tidLvls.push_back(
980  env.makeTensorLevel(tid, env.emitter().getCurrentDepth()));
981  }
982  needsUniv = true;
983  }
984  if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
985  is2OutOf4LT(lt) || isIdxReduc) {
986  // Only when this is an index reduction loop can the lt be undefined.
987  assert(!isUndefLT(lt) || isIdxReduc);
988  // sparse/singleton levels, or a dense/sparse index reduction loop.
989  tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
990  }
991  });
992 
993  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
994 
995  // Maintain the universal index only if it is actually
996  // consumed by a subsequent lattice point.
997  if (needsUniv) {
998  for (const LatPointId li : env.set(lts).drop_front())
999  if (!env.merger().hasAnySparse(env.lat(li).simple))
1000  return true;
1001  }
1002  return false;
1003 }
1004 
1005 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1006  OpBuilder &builder, TensorId tid,
1007  Level startLvl) {
1008  // TODO: Handle affine expression on output tensor.
1009  linalg::GenericOp op = env.op();
1010  assert(tid < op.getNumDpsInputs());
1011  OpOperand *input = op.getDpsInputOperands()[tid];
1012  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1013  const auto enc = getSparseTensorEncoding(input->get().getType());
1014  if (enc) {
1015  const Location loc = op.getLoc();
1016  const TensorId tid = env.makeTensorId(input->getOperandNumber());
1017  const Level lvlRank = enc.getLvlRank();
1018  assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1019  for (Level l = startLvl; l < lvlRank; l++) {
1020  AffineExpr lvlExpr = lvlExprs[l];
1021  if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
1022  env.emitter().genDenseAffineAddress(
1023  builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1024  else
1025  return; // break on first non-dense non-constant level
1026  }
1027  }
1028 }
1029 
1030 static void genInitConstantDenseAddress(CodegenEnv &env,
1031  RewriterBase &rewriter) {
1032  // We can generate addresses for constant affine expressions before any
1033  // loops, starting from the first level, as they do not depend on anything.
1034  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1035  // levels can be determined before loops.
1036  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1037  genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1038 }
1039 
1040 /// Returns true if the lattice bit can be iterated by a for loop.
1041 static bool translateBitsToTidLvlPairs(
1042  CodegenEnv &env, LatPointId li, LoopId ldx,
1043  SmallVectorImpl<TensorLevel> &tidLvls,
1044  SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1045  const BitVector &simple = env.lat(li).simple;
1046  const TensorId outTid = env.merger().getOutTensorID();
1047  const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
1048 
1049  unsigned numloopCond = 0;
1050  bool hasNonUnique = false;
1051  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
1052  std::optional<Level> lvl,
1053  LevelType lt, bool isIdxReduc) {
1054  if (simple[b]) {
1055  if (isIdxReduc) {
1056  tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
1057  numloopCond++;
1058  return;
1059  }
1060  if (isUndefLT(lt)) {
1061  // An undefined lt in the lattices means we probably intend to
1062  // generate a dense loop according to the synthetic tensor (for
1063  // invariants and the sparse output tensor).
1064  if (env.merger().getSynTensorID() == tid) {
1065  // Coiterating with an invariant
1066  // e.g., out = prod(in[i][j] op invariant);
1067  // or a broadcast
1068  // e.g., out[i][j] = in[i] (j is undef for input)
1069  //
1070  // The level of the synthetic tensor is the current loop depth;
1071  // the rank of the synthetic tensor equals to number of loops.
1072  lvl = env.emitter().getCurrentDepth();
1073  } else if (!lvl) {
1074  // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1075  return;
1076  }
1077  }
1078  hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1079  tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
1080  numloopCond++;
1081  } else if (isDenseLT(lt) || isIdxReduc) {
1082  tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
1083  } else {
1084  assert(isUndefLT(lt));
1085  linalg::GenericOp op = env.op();
1086  if (tid >= op.getNumDpsInputs())
1087  // We only handle affine expression on input tensors (for now).
1088  return;
1089  OpOperand *operand = &op->getOpOperand(tid);
1090  const auto stt = getSparseTensorType(operand->get());
1091  // Non-annotated dense tensors require no special handling.
1092  if (!stt.hasEncoding())
1093  return;
1094 
1095  ArrayRef<AffineExpr> affines =
1096  op.getMatchingIndexingMap(operand).getResults();
1097  const Level lvlRank = stt.getLvlRank();
1098  assert(affines.size() == static_cast<size_t>(lvlRank));
1099  for (Level l = 0; l < lvlRank; l++) {
1100  AffineExpr exp = affines[l];
1101  // Skip simple affine expressions and non-dense levels (which
1102  // have their own filter loop).
1103  if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
1104  continue;
1105 
1106  // Constant affine expressions are handled in genLoop.
1107  if (!isa<AffineConstantExpr>(exp)) {
1108  bool isAtLoop = false;
1109  if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
1110  isAtLoop) {
1111  // If the compound affine expression is invariant and we are right at
1112  // the level, we need to generate the address according to the
1113  // affine expression. This is also the best place to do it,
1114  // to avoid putting it inside inner loops.
1115  // NOTE: It assumes that the levels of the input tensor are
1116  // initialized in order (which is also currently guaranteed by
1117  // computeIterationGraph); another more permissive approach
1118  // might be to accept out-of-order accesses between consecutive
1119  // dense levels.
1120  affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
1121  }
1122  }
1123  }
1124  }
1125  });
1126 
1127  if (isDenseLT(env.lt(outTid, ldx))) {
1128  // Note that we generate dense indices of the output tensor
1129  // unconditionally, since they may not appear in the lattice, but may be
1130  // needed for linearized env.
1131  tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
1132  }
1133 
1134  if (numloopCond == 0) {
1135  // Corner case where the loop bound is defined by an *unused* operand; in
1136  // this case, we just generate a dense "fake" loop by iterating over the
1137  // synthetic tensor.
1138  tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
1139  env.emitter().getCurrentDepth()));
1140  numloopCond++;
1141  }
1142  // If we just need one loop condition and the condition is not imposed on
1143  // a non-unique level, the loop can be generated by a for loop.
1144  return numloopCond == 1 && !hasNonUnique;
1145 }
1146 
1147 /// Starts a single loop in current sequence.
1148 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1149  OpBuilder &builder, LoopOrd at,
1150  LatPointId li, bool needsUniv) {
1151  // The set of tensors + lvls to generate loops on
1152  SmallVector<TensorLevel> tidLvls;
1153  // The set of dense tensors with non-trivial affine expressions that just
1154  // became invariant; their addresses shall now be generated at the current
1155  // level.
1156  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1157  bool isSingleCond =
1158  translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
1159 
1160  // Emit the for/while-loop control.
1161  Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls);
1162  Location loc = env.op().getLoc();
1163  for (auto [tidLvl, exp] : affineTidLvls) {
1164  env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
1165  }
1166 
1167  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1168  // affine}Tids/Lvls. The addresses of the upcoming levels which depend
1169  // on constant affine expressions may now be determined.
1170  auto allTidLvls =
1171  llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1172  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1173  if (tid != env.merger().getOutTensorID() &&
1174  tid != env.merger().getSynTensorID())
1175  genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1176  }
1177 
1178  return std::make_pair(loop, isSingleCond);
1179 }
1180 
1181 /// Ends a single loop in current sequence. Returns new values for needsUniv.
1182 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1183  LoopId idx, LatPointId li, bool needsUniv,
1184  bool isSingleCond) {
1185 
1186  if (isSingleCond) {
1187  // Either a for-loop or a while-loop that iterates over a slice.
1188  // Any iteration creates a valid lex insert.
1189  if (env.isReduc() && env.getValidLexInsert())
1190  env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1191  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1192  // End a while-loop.
1193  finalizeWhileOp(env, rewriter, idx, needsUniv);
1194  } else {
1195  needsUniv = false;
1196  }
1197 
1198  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1199  env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1200  return std::nullopt;
1201  });
1202 
1203  return needsUniv;
1204 }
1205 
1206 /// Ends a loop sequence at given level.
1207 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1208  unsigned idx, unsigned ldx) {
1209  assert(!env.getLoopVar(idx));
1210  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1211  // Unmark bookkeeping of invariants and loop index.
1212  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
1213  // Finalize access pattern expansion for sparse tensor output.
1214  genExpand(env, builder, idx, /*atStart=*/false);
1215 }
1216 
1217 /// Recursively generates code while computing iteration lattices in order
1218 /// to manage the complexity of implementing co-iteration over unions
1219 /// and intersections of sparse iterations spaces.
1220 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1221  LoopOrd at) {
1222  // At each leaf, assign remaining tensor (sub)expression to output tensor.
1223  if (at == env.getLoopNum()) {
1224  Value rhs = genExp(env, rewriter, exp, at - 1);
1225  genTensorStore(env, rewriter, exp, rhs);
1226  return;
1227  }
1228 
1229  // Construct iteration lattices for current loop index, with L0 at top.
1230  const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
1231  const LatSetId lts =
1232  env.merger().optimizeSet(env.merger().buildLattices(exp, at));
1233 
1234  // Start a loop sequence.
1235  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
1236 
1237  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1238  //
1239  // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
1240  // because the loop body causes data-movement which invalidates
1241  // the iterator.
1242  const unsigned lsize = env.set(lts).size();
1243  for (unsigned i = 0; i < lsize; i++) {
1244  const LatPointId li = env.set(lts)[i];
1245  // Start a loop.
1246  auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
1247 
1248  // Visit all lattices points with Li >= Lj to generate the
1249  // loop-body, possibly with if statements for coiteration.
1250  Value redInput = env.getReduc();
1251  Value cntInput = env.getExpandCount();
1252  Value insInput = env.getInsertionChain();
1253  Value validIns = env.getValidLexInsert();
1254  // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
1255  // because the loop body causes data-movement which invalidates the
1256  // iterator.
1257  for (unsigned j = 0; j < lsize; j++) {
1258  const LatPointId lj = env.set(lts)[j];
1259  const ExprId ej = env.lat(lj).exp;
1260  if (li == lj || env.merger().latGT(li, lj)) {
1261  // Recurse into body of each branch.
1262  if (!isSingleCond) {
1263  scf::IfOp ifOp = genIf(env, rewriter, at, lj);
1264  genStmt(env, rewriter, ej, at + 1);
1265  endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1266  } else {
1267  genStmt(env, rewriter, ej, at + 1);
1268  }
1269  }
1270  }
1271 
1272  // End a loop.
1273  needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
1274  }
1275 
1276  // End a loop sequence.
1277  endLoopSeq(env, rewriter, exp, at, ldx);
1278 }
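// As an illustration (hand-written sketch, not the exact output): for
// x(i) = a(i) + b(i) with both a and b compressed, the lattice points
// {a and b, a only, b only} materialize as a coiterating while-loop with
// if-guards selecting which branch executes:
//
//   while (pa < pa_end && pb < pb_end) {
//     ia = crd_a[pa]; ib = crd_b[pb]; i = min(ia, ib);
//     if (ia == i && ib == i)  x[i] = val_a[pa] + val_b[pb];
//     else if (ia == i)        x[i] = val_a[pa];
//     else                     x[i] = val_b[pb];
//     pa += (ia == i); pb += (ib == i);
//   }
//   // the remaining tail of a or b is handled by the subsequent lattice points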
1279 
1280 /// Converts the result computed by the sparse kernel into the required form.
1281 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1282  linalg::GenericOp op = env.op();
1283  OpOperand *lhs = op.getDpsInitOperand(0);
1284  Value tensor = lhs->get();
1285  Type resType = tensor.getType();
1286  if (getSparseTensorEncoding(resType)) {
1287  // The sparse tensor rematerializes from the original sparse tensor's
1288  // underlying sparse storage format. For an insertion chain, the
1289  // tensor materializes from the chain with 'hasInserts' enabled.
1290  bool hasInserts = false;
1291  if (Value chain = env.getInsertionChain()) {
1292  hasInserts = true;
1293  tensor = chain;
1294  }
1295  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1296  } else {
1297  // To rematerialize a non-annotated tensor, simply load it
1298  // from the bufferized value.
1299  Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1300  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1301  }
1302 }
1303 
1304 //===----------------------------------------------------------------------===//
1305 // Sparsifier rewriting methods.
1306 //===----------------------------------------------------------------------===//
1307 
1308 namespace {
1309 
1310 /// Sparse rewriting rule for generic Linalg operation.
1311 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1312 public:
1313  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1314  : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1315 
1316  LogicalResult matchAndRewrite(linalg::GenericOp op,
1317  PatternRewriter &rewriter) const override {
1318  // Only accept single output operations with pure tensor semantics.
1319  if (op.getNumDpsInits() != 1 || !op.hasTensorSemantics())
1320  return failure();
1321 
1322  // Only accept trivial affine indices.
1323  if (hasNonTrivialAffineOnSparseOut(op))
1324  return failure();
1325 
1326  if (!op->hasAttr("sorted")) {
1327  return rewriter.notifyMatchFailure(
1328  op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1329  "before sparsification.");
1330  }
1331  // Must have been demapped as well if the generic op is sorted.
1332  assert(!hasAnyNonIdentityOperandsOrResults(op));
1333 
1334  // Sets up a code generation environment.
1335  const unsigned numTensors = op->getNumOperands();
1336  const unsigned numLoops = op.getNumLoops();
1337  bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1338  // If we have an indexing map like (d0) -> (0, d0), there might be more
1339  // levels than loops because of the constant index, which means we cannot
1340  // use numLoops as the upper bound for the ranks of all tensors.
1341  // TODO: Constant indices are currently not supported on sparse tensors, but
1342  // are allowed in non-annotated dense tensors. Supporting them would also be
1343  // required for sparse tensor slice rank reduction.
1344  Level maxLvlRank = 0;
1345  for (auto operand : op.getOperands()) {
1346  if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1347  maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1348  }
1349  }
1350 
1351  CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1352  // Detects sparse annotations and translates the per-level sparsity
1353  // information for all tensors to loop indices in the kernel.
1354  if (!findSparseAnnotations(env, needIdxRed))
1355  return failure();
1356 
1357  // Only standard reduction operations (add, sub, or, xor) that can be
1358  // sparsified by merely reducing the stored values are admissible. More
1359  // elaborate reduction operations (such as mul, and, min, max) would need
1360  // to know whether implicit zeros occur as well. They can still be
1361  // implemented with a custom reduction operation, accepted here as well.
1362  if (op.getNumReductionLoops() > 0) {
1363  Operation *yield = op.getRegion().front().getTerminator();
1364  assert(isa<linalg::YieldOp>(yield));
1365  Operation *redop = yield->getOperand(0).getDefiningOp();
1366  if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1367  !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1368  !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1369  !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1370  !isa<ReduceOp>(redop)) {
1371  return failure();
1372  }
1373  }
1374 
1375  // Constructs the tensor expression tree from `op`; returns failure if the
1376  // tree cannot be built or the tensor expression is inadmissible.
1377  if (failed(env.initTensorExp()))
1378  return failure();
1379 
1380  // Recursively generates code if admissible.
1381  env.startEmit();
1382  genBuffers(env, rewriter);
1383  // TODO: Constant affine expressions should be handled differently when using
1384  // slice-based codegen; it does not matter now because we already reject
1385  // constant expressions at an earlier stage.
1386  genInitConstantDenseAddress(env, rewriter);
1387  genStmt(env, rewriter, env.getExprId(), 0);
1388  genResult(env, rewriter);
1389  return success();
1390  }
1391 
1392 private:
1393  /// Options to control sparse code generation.
1394  const SparsificationOptions options;
1395 };
1396 
1397 } // namespace
1398 
1399 /// Populates the given patterns list with rewriting rules required for
1400 /// the sparsification of linear algebra operations.
1401 void mlir::populateSparsificationPatterns(
1402  RewritePatternSet &patterns, const SparsificationOptions &options) {
1403  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1404 }
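
For readers wiring these patterns into their own pipeline, the following is a minimal sketch, not part of Sparsification.cpp: a hypothetical helper (sparsifyModule) that registers the rewriting rules via populateSparsificationPatterns and applies them with MLIR's greedy pattern rewrite driver. It assumes the linalg.generic ops have already been scheduled and demapped, e.g. by running --sparse-reinterpret-map first, as required by the "sorted" check above.

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
  #include <utility>

  // Hypothetical helper, not part of Sparsification.cpp: gathers the
  // sparsification patterns and applies them to all ops nested in a module
  // with the greedy pattern rewrite driver.
  static mlir::LogicalResult
  sparsifyModule(mlir::ModuleOp module,
                 const mlir::SparsificationOptions &options) {
    mlir::RewritePatternSet patterns(module.getContext());
    mlir::populateSparsificationPatterns(patterns, options);
    // Each admissible linalg.generic op is rewritten in place into sparse code.
    return mlir::applyPatternsAndFoldGreedily(module.getOperation(),
                                              std::move(patterns));
  }

The same ordering constraint applies to pass pipelines driven from mlir-opt: the sparse-reinterpret-map pass must run before the sparsification pass, otherwise the pattern above bails out with the "Loops not yet scheduled" match failure.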