Sparsification.cpp
1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenEnv.h"
14 #include "Utils/CodegenUtils.h"
15 #include "Utils/LoopEmitter.h"
16 
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TensorEncoding.h"
36 #include "llvm/ADT/SmallBitVector.h"
37 
38 #include <optional>
39 
40 using namespace mlir;
41 using namespace mlir::sparse_tensor;
42 
43 //===----------------------------------------------------------------------===//
44 // Sparsifier analysis methods.
45 //===----------------------------------------------------------------------===//
46 
47 /// Returns true iff affine expression is invariant. Sets the
48 /// parameter `isCurrentLoop` when expression just became invariant.
49 static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
50  switch (a.getKind()) {
51  case AffineExprKind::DimId: {
52  const LoopId i = cast<AffineDimExpr>(a).getPosition();
53  if (i + 1 == curr) {
54  isCurrentLoop = true;
55  return true; // becomes invariant at current loop
56  }
57  return i < curr; // invariant when already generated
58  }
 59  case AffineExprKind::Add:
 60  case AffineExprKind::Mul: {
61  auto binOp = cast<AffineBinaryOpExpr>(a);
62  return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
63  isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
64  }
65  default: {
66  assert(isa<AffineConstantExpr>(a));
67  return true;
68  }
69  }
70 }
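//
// For illustration: with `a = d0 + d1` the mechanics above yield
//
//   bool atCurr = false;
//   bool inv = isInvariantAffine(a, /*curr=*/2, atCurr);  // inv == true
//
// since d0 is already invariant (0 < 2) and d1 just became invariant at the
// current loop (1 + 1 == 2), which also sets `atCurr` to true.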
71 
72 /// Helper method to inspect affine expressions. Rejects cases where the
73 /// same index is used more than once. Also rejects compound affine
74 /// expressions in sparse dimensions.
75 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
76  LevelType lt, bool setLvlFormat = true) {
77  switch (a.getKind()) {
78  case AffineExprKind::DimId: {
79  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
80  if (!isUndefLT(merger.getLvlType(tid, idx)))
81  return false; // used more than once
82  if (setLvlFormat)
83  merger.setLevelAndType(tid, idx, lvl, lt);
84  return true;
85  }
 86  case AffineExprKind::Add:
 87  case AffineExprKind::Mul:
 88  case AffineExprKind::Constant: {
 89  assert(lt.hasDenseSemantic());
90  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
91  // We do not set dim level format for affine expression like d0 + d1 on
92  // either loop index at d0 or d1. We continue the recursion merely to
93  // check whether current affine is admissible or not.
94  return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
95  findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
96  }
 97  // Falls through when it is a constant affine expression.
98  return true;
99  }
100  default:
101  return false;
102  }
103 }
104 
105 /// Helper method to inspect affine expressions for index variable reduction
106 /// based codegen. It finds the dependent index set for all tensor levels in the
107 /// current expression we are generating.
108 ///
109 /// For example, when handling A[i+j][j+k], we build the two way mapping in
110 /// merger between (tensor, level) pairs and their dependent index variable set:
111 /// A_0 <=> [i, j] and A_1 <=> [j, k]
112 ///
113 /// It rejects cases (returns false)
114 /// 1st, when the same index is used more than once, e.g., A[i+j][i]
115 /// 2nd, when multiplication is used in the non-trivial index expression.
116 /// 3rd, when a constant operand is used in the non-trivial index expression.
117 ///
118 /// TODO: constant should be easy to handle.
119 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
120  AffineExpr a, LevelType lt, bool isSubExp = false,
121  int64_t coefficient = 1) {
122  switch (a.getKind()) {
123  case AffineExprKind::DimId: {
124  // Only allow positive coefficients on AffineDimExpr.
125  if (coefficient <= 0)
126  return false;
127 
128  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
129  if (!isUndefLT(merger.getLvlType(tensor, idx)))
130  return false; // used more than once, e.g., A[i][i]
131 
 132  // TODO: Generalize the following two cases. A[i] (with trivial index
133  // expression) can be treated as a special affine index expression. We do
134  // not necessarily need to differentiate them.
135  if (!isSubExp) {
136  assert(coefficient == 1);
137  merger.setLevelAndType(tensor, idx, lvl, lt);
138  }
139 
140  if (isSubExp) {
 141  // The current loop appears in more than one affine expression on the
 142  // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where
 143  // `i` is used twice.
144  if (merger.hasDependentLvl(idx, tensor)) {
 145  // TODO: This can be supported by coiterating slices if the loop idx
 146  // appears in affine indices of different tensors, or by taking a slice
 147  // on multiple dimensions when it is on the same tensor.
148  // E.g.,
149  // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
150  // d0_1 = getNextSliceOffset t0 along lvl0
151  // d0_2 = getNextSliceOffset t1 along lvl0
 152  // if d0_1 == d0_2 then d0 = d0_1 = d0_2
153  // else increase min(d0_1, d0_2).
154  return false;
155  }
156  merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
157  }
158  return true;
159  }
161  case AffineExprKind::Mul: {
 162  // TODO: Support index expressions like `2 * d0`; we currently only support
 163  // the more complicated cases like `2 * d0 + d1`.
164  if (!isSubExp)
165  return false;
166 
167  // TODO: Support Constant AffineExp for slice-based codegen
168  if (isa<AffineConstantExpr>(a))
169  llvm_unreachable("Not yet implemented");
170 
171  auto binOp = cast<AffineBinaryOpExpr>(a);
172  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
173  if (isa<AffineConstantExpr>(rhs))
174  std::swap(lhs, rhs);
175  // Must be in form of `constant * d`.
176  assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
177  int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
178  return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
179  }
180  case AffineExprKind::Add: {
181  auto binOp = cast<AffineBinaryOpExpr>(a);
182  return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
183  findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
184  }
185  default:
186  return false;
187  }
188 }
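//
// For illustration: an access such as A[2 * d0 + d1] is admitted as follows.
// The Add case recurses into both operands with `isSubExp = true`; the Mul
// case extracts the constant 2 and forwards it as `coefficient`; the DimId
// cases then record d0 (coefficient 2) and d1 (coefficient 1) as dependent
// index variables of this tensor level via setLoopDependentTensorLevel.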
189 
190 /// Gets the total number of compound affine expressions in the
191 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
192 ///
193 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
194 ///
195 /// Returns 1 (because the first level is compressed and its corresponding
196 /// indexing-expression is `d0 + d1`)
 197 static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
 198  Value tensor) {
199  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
200  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
201  // However, we don't need to handle `StorageSpecifierType`, so we
202  // can use `SparseTensorType` once we guard against non-tensors.
203  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
204  if (!rtp)
205  return 0;
206  const SparseTensorType stt(rtp);
207 
208  const Level lvlRank = stt.getLvlRank();
209  const auto exprs = map.getResults();
210  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
211  "AffineMap does not have dimension-rank many results");
212  unsigned num = 0;
213  for (Level l = 0; l < lvlRank; l++) {
214  if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
215  num++;
216  }
217  return num;
218 }
219 
220 /// Gets the total number of sparse levels with compound affine
221 /// expressions, summed over all operands of the `GenericOp`.
222 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
223  unsigned num = 0;
224  for (OpOperand &t : op->getOpOperands())
225  num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
226  t.get());
227  return num;
228 }
229 
230 // Returns true iff output has nontrivial affine indices.
231 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
232  OpOperand *out = op.getDpsInitOperand(0);
233  if (getSparseTensorType(out->get()).isAllDense())
234  return false;
235  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
236  out->get());
237 }
238 
239 /// Helper method to inspect sparse encodings in the tensor types.
240 /// Fills the per-dimension sparsity information for all tensors.
241 /// Returns true if the sparse annotations and affine subscript
242 /// expressions of all tensors are admissible. Returns false if
243 /// no annotations are found or inadmissible constructs occur.
 244 /// We currently support two different ways to handle non-trivial index
 245 /// expressions on sparse tensors, and they accept different affine expressions.
 246 /// When using the dependent index reduction-based approach, it currently only
 247 /// supports affine addition index expressions.
248 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
249  bool annotated = false;
250  for (OpOperand &t : env.op()->getOpOperands()) {
251  const TensorId tid = env.makeTensorId(t.getOperandNumber());
252  const auto map = env.op().getMatchingIndexingMap(&t);
253  const auto enc = getSparseTensorEncoding(t.get().getType());
254  if (enc)
255  annotated = true;
256  const Level lvlRank = map.getNumResults();
257  assert(!enc || lvlRank == enc.getLvlRank());
258  assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
259  // We only need to do index reduction if there is at least one
 260  // non-trivial index expression on sparse levels. If all non-trivial
 261  // index expressions are on dense levels, we can efficiently rely on
 262  // random access to locate the elements.
263  bool needIdxReduc =
264  enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
 265  // If the current tensor being inspected requires affine indices, it needs
 266  // to be sliced.
267  for (Level l = 0; l < lvlRank; l++) {
268  const AffineExpr a = map.getResult(l);
269  const LevelType lt = enc.getLvlType(l);
270  if (idxReducBased && needIdxReduc) {
271  if (!findDepIdxSet(env.merger(), tid, l, a, lt))
272  return false; // inadmissible affine expression
273  } else {
274  if (!findAffine(env.merger(), tid, l, a, lt))
275  return false; // inadmissible affine expression
276  }
277  }
278  }
279  return annotated;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Sparsifier synthesis methods (statements and expressions).
284 //===----------------------------------------------------------------------===//
285 
286 /// Local bufferization of all dense and sparse data structures.
287 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
288  linalg::GenericOp op = env.op();
289  Location loc = op.getLoc();
290  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
291 
292  SmallVector<Range, 4> loopRange =
293  llvm::cast<linalg::LinalgOp>(op.getOperation())
294  .createLoopRanges(builder, loc);
295 
 296  env.emitter().initializeLoopEmit(
 297  builder, loc,
298  /// Generates buffer for the output tensor.
299  /// Note that all sparse kernels assume that when all elements are written
300  /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
 301  /// to all zeroes and only nonzero values are computed and written out.
 302  /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
303  /// for the updates and no assumption on the original contents of the
304  /// output buffer is necessary.
305  [&op](OpBuilder &builder, Location loc, Value memref,
306  Value tensor) -> Value {
307  // Must not be a sparse tensor.
308  assert(!getSparseTensorEncoding(tensor.getType()));
309  // Two output tensor references should point to the same object.
310  OpOperand *lhs = op.getDpsInitOperand(0);
311  assert(lhs->get() == tensor);
312  // An output tensor can simply materialize from the buffer of the tensor
313  // that appears in the outs() clause. For updates, this has the
 314  // advantage that only the nonzero values are involved in the
315  // computation, keeping the operation O(nnz). In all other cases, we are
316  // forced to zero out the buffer to enforce the assumption above, which
317  // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
318  // O(nnz) for matrices).
319  // TODO: use better analysis to avoid zeroing out the buffer?
320  bool isInit = op.isInitTensor(lhs);
321  Value init = memref;
322  if (!isInit) {
323  Value zero = constantZero(builder, loc,
324  getElementTypeOrSelf(tensor.getType()));
325  builder.create<linalg::FillOp>(loc, ValueRange{zero},
326  ValueRange{init});
327  }
328  return init;
329  },
330  [&loopRange](OpBuilder &b, Location loc, Level l) {
331  assert(l < loopRange.size());
332  return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
333  });
334 }
335 
336 /// Generates index for load/store on sparse tensor.
337 static Value genIndex(CodegenEnv &env, OpOperand *t) {
338  const auto map = env.op().getMatchingIndexingMap(t);
339  const auto stt = getSparseTensorType(t->get());
340  const Level lvlRank = stt.getLvlRank();
341  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
342  const AffineExpr a = map.getResult(lvlRank - 1);
343  assert(a.getKind() == AffineExprKind::DimId);
344  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
345  return env.getLoopVar(idx);
346 }
347 
348 /// Generates subscript for load/store on a dense or sparse tensor.
349 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
350  SmallVectorImpl<Value> &args) {
351  const Location loc = env.op().getLoc();
352  const TensorId tid = env.makeTensorId(t->getOperandNumber());
353  const auto map = env.op().getMatchingIndexingMap(t);
354  const auto stt = getSparseTensorType(t->get());
355  if (stt.hasEncoding()) {
356  // For sparse tensors we only push the last-level's position onto `args`.
357  const auto pos = env.emitter().getValPosits(tid);
358  assert(!pos.empty());
359  args.append(pos);
360  // Simply returns the tensor to extract value using iterators.
 361  if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
 362  return t->get();
363  } else {
364  // For dense tensors we push all level's coordinates onto `args`.
365  const Level lvlRank = stt.getLvlRank();
366  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
367  for (Level l = 0; l < lvlRank; l++) {
368  const auto lvlExpr = map.getResult(l);
369  const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
370  args.push_back(lvlCrd);
371  }
372  }
373  return env.emitter().getValBuffer()[tid];
374 }
375 
376 /// Generates insertion code to implement dynamic tensor load.
 377 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
 378  OpOperand *t) {
379  linalg::GenericOp op = env.op();
380  Location loc = op.getLoc();
381  // Direct lexicographic coordinate order, tensor loads as zero.
382  if (!env.isExpand()) {
383  Type tp = getElementTypeOrSelf(t->get().getType());
384  return constantZero(builder, loc, tp);
385  }
386  // Load from expanded access pattern.
387  Value index = genIndex(env, t);
388  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
389 }
390 
391 /// Generates insertion code to implement dynamic tensor load for reduction.
 392 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
 393  OpOperand *t) {
394  linalg::GenericOp op = env.op();
395  Location loc = op.getLoc();
396  Value identity = env.getCustomRedId();
397  // Direct lexicographic coordinate order, tensor loads as identity.
398  if (!env.isExpand())
399  return identity;
400  // Load from expanded access pattern if filled, identity otherwise.
401  Value values = env.getExpandValues();
402  Value filled = env.getExpandFilled();
403  Value index = genIndex(env, t);
404  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
405  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
406  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
407 }
408 
409 static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
410  Value sparseOut, ValueRange ivs, Value v) {
411  scf::IfOp condInsert =
412  builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
413  // True branch.
414  builder.setInsertionPointToStart(condInsert.thenBlock());
415  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
416  builder.create<scf::YieldOp>(loc, res);
417  // False branch.
418  builder.setInsertionPointToStart(condInsert.elseBlock());
419  builder.create<scf::YieldOp>(loc, sparseOut);
420  // Value assignment.
421  builder.setInsertionPointAfter(condInsert);
422  return condInsert.getResult(0);
423 }
424 
425 /// Generates insertion code to implement dynamic tensor store.
426 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
427  Value rhs) {
428  linalg::GenericOp op = env.op();
429  Location loc = op.getLoc();
430  // Direct insertion in lexicographic coordinate order.
431  if (!env.isExpand()) {
432  const LoopId numLoops = op.getRank(t);
 433  // Retrieves the first `numLoops` induction variables.
434  SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
435  env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
436  Value chain = env.getInsertionChain();
437  if (env.isValidLexInsert()) {
438  // Generates runtime check for a valid lex during reduction,
439  // to avoid inserting the identity value for empty reductions.
440  // if (validLexInsert) then
441  // insert(rhs) into chain
442  // return updated chain
443  // else
444  // return unmodified chain
445  Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
446  chain, ivs, rhs);
447  env.updateInsertionChain(out);
448  } else {
449  Value sparseOut;
450  if (!hasAnySparseType(env.op().getInputs().getTypes())) {
451  // This is an all-dense -> sparse kernel, test rhs != 0 before
452  // insertion.
453  Value nz = genIsNonzero(builder, loc, rhs);
454  sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
455  } else {
456  sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
457  }
458  // Generates regular insertion chain.
459  env.updateInsertionChain(sparseOut);
460  }
461  return;
462  }
463  // Generates insertion code along expanded access pattern.
464  // if (!expFilled[i]) then
465  // expFilled[i] = true
466  // expAdded[inserts++] = i
467  // endif
468  // values[i] = rhs
469  Value values = env.getExpandValues();
470  Value filled = env.getExpandFilled();
471  Value added = env.getExpandAdded();
472  Value count = env.getExpandCount();
473  Value index = genIndex(env, t);
474  Value fval = constantI1(builder, loc, false);
475  Value tval = constantI1(builder, loc, true);
476  // If statement.
477  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
478  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
479  isFilled, fval);
480  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
481  /*else=*/true);
482  // True branch.
483  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
484  builder.create<memref::StoreOp>(loc, tval, filled, index);
485  builder.create<memref::StoreOp>(loc, index, added, count);
486  Value one = constantIndex(builder, loc, 1);
487  Value add = builder.create<arith::AddIOp>(loc, count, one);
488  builder.create<scf::YieldOp>(loc, add);
489  // False branch.
490  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
491  builder.create<scf::YieldOp>(loc, count);
492  builder.setInsertionPointAfter(ifOp);
493  // Value assignment.
494  env.updateExpandCount(ifOp.getResult(0));
495  builder.create<memref::StoreOp>(loc, rhs, values, index);
496 }
497 
498 /// Generates a load on a dense or sparse tensor.
499 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
500  // Test if the load was hoisted to a higher loop nest.
501  Value val = env.exp(exp).val;
502  if (val)
503  return val;
504  // Get tensor operand.
505  linalg::GenericOp op = env.op();
506  Location loc = op.getLoc();
507  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
508  // Fold binary-valued tensor into explicit value.
509  const auto stt = getSparseTensorType(t->get());
510  if (auto explVal = stt.getExplicitVal())
511  return genValFromAttr(builder, loc, explVal);
512  // Load during insertion.
513  if (env.isSparseOutput(t)) {
514  if (env.isCustomReduc())
515  return genInsertionLoadReduce(env, builder, t);
516  return genInsertionLoad(env, builder, t);
517  }
518 
519  // Actual load.
520  SmallVector<Value> args;
521  Value ptr = genSubscript(env, builder, t, args);
522  if (llvm::isa<TensorType>(ptr.getType())) {
523  assert(env.options().sparseEmitStrategy ==
 524  SparseEmitStrategy::kSparseIterator &&
 525  args.size() == 1);
526  return builder.create<ExtractValOp>(loc, ptr, args.front());
527  }
528  return builder.create<memref::LoadOp>(loc, ptr, args);
529 }
530 
531 /// Generates a store on a dense or sparse tensor.
532 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
533  Value rhs) {
 534  // Only unary and binary are allowed to return an uninitialized rhs
 535  // to indicate missing output, or otherwise a custom reduction that
 536  // received no value to accumulate.
537  if (!rhs) {
538  assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
539  env.exp(exp).kind == TensorExp::Kind::kBinary ||
540  env.exp(exp).kind == TensorExp::Kind::kReduce);
541  return;
542  }
543  // Test if this is a scalarized reduction.
544  if (env.isReduc()) {
545  env.updateReduc(rhs);
546  return;
547  }
548  // Regular store.
549  linalg::GenericOp op = env.op();
550  Location loc = op.getLoc();
551  OpOperand *t = op.getDpsInitOperand(0);
552  if (!env.isSparseOutput(t)) {
553  SmallVector<Value> args;
554  Value ptr = genSubscript(env, builder, t, args);
555  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
556  return;
557  }
558  // Store during sparse insertion.
559  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
560  genInsertionStore(env, builder, t, rhs);
561  return;
562  }
563  // Select operation insertion.
564  Value chain = env.getInsertionChain();
565  scf::IfOp ifOp =
566  builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
567  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
568  // Existing value was preserved to be used here.
569  assert(env.exp(exp).val);
570  Value v0 = env.exp(exp).val;
571  genInsertionStore(env, builder, t, v0);
572  env.merger().clearExprValue(exp);
573  // Yield modified insertion chain along true branch.
574  Value mchain = env.getInsertionChain();
575  builder.create<scf::YieldOp>(op.getLoc(), mchain);
576  // Yield original insertion chain along false branch.
577  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
578  builder.create<scf::YieldOp>(loc, chain);
579  // Done with if statement.
580  env.updateInsertionChain(ifOp->getResult(0));
581  builder.setInsertionPointAfter(ifOp);
582 }
583 
584 /// Generates an invariant value.
585 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
586  return env.exp(exp).val;
587 }
588 
589 /// Semi-ring branches are simply inlined by the sparsifier. Prior
590 /// analysis has verified that all computations are "local" to the inlined
591 /// branch or otherwise invariantly defined outside the loop nest, with the
592 /// exception of index computations, which need to be relinked to actual
593 /// inlined cloned code.
594 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
595  Value e) {
596  if (auto arg = dyn_cast<BlockArgument>(e)) {
597  // Direct arguments of the original linalg op must be converted
598  // into dense tensor loads. Note that we should not encounter
599  // anything else. This needs to be verified by semi-ring ops.
600  linalg::GenericOp op = env.op();
601  if (arg.getOwner()->getParentOp() == op) {
602  const TensorId tid = env.makeTensorId(arg.getArgNumber());
603  OpOperand *t = &op->getOpOperand(tid);
604  assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
605  SmallVector<Value> args;
606  Value ptr = genSubscript(env, rewriter, t, args);
607  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
608  }
609  } else if (Operation *def = e.getDefiningOp()) {
610  // Handle index computation.
611  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
612  return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
613  // When still defined in new body, recurse into operands.
614  if (def->getBlock() == block) {
615  rewriter.setInsertionPoint(def);
616  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
617  rewriter.modifyOpInPlace(def, [&]() {
618  def->setOperand(
619  i, relinkBranch(env, rewriter, block, def->getOperand(i)));
620  });
621  }
622  }
623  }
624  return e;
625 }
626 
627 /// Recursively generates tensor expression.
628 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
 629  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
 630  return Value();
631 
632  linalg::GenericOp op = env.op();
633  Location loc = op.getLoc();
634  const TensorExp &exp = env.exp(e);
635  const auto kind = exp.kind;
636  if (kind == TensorExp::Kind::kTensor)
637  return genTensorLoad(env, rewriter, e);
638  if (kind == TensorExp::Kind::kInvariant)
639  return genInvariantValue(env, e);
640  if (kind == TensorExp::Kind::kLoopVar)
641  return env.getLoopVar(exp.loop);
642 
643  if (kind == TensorExp::Kind::kReduce)
644  env.startCustomReduc(e); // enter custom
645 
646  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
647  // based on the type of the other operand.
648  Value v0, v1;
 649  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
 650  env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
 651  v1 = genExp(env, rewriter, exp.children.e1);
652  v0 = constantZero(rewriter, loc, v1.getType());
 653  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
 654  env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
 655  v0 = genExp(env, rewriter, exp.children.e0);
656  v1 = constantZero(rewriter, loc, v0.getType());
657  } else {
658  v0 = genExp(env, rewriter, exp.children.e0);
659  v1 = genExp(env, rewriter, exp.children.e1);
660  }
661 
662  Value ee;
663  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
664  // custom reduce did not receive a value
665  } else {
666  ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
667  if (ee &&
 668  (kind == TensorExp::Kind::kUnary ||
 669  kind == TensorExp::Kind::kBinary ||
 670  kind == TensorExp::Kind::kReduce ||
671  kind == TensorExp::Kind::kSelect)) {
672  OpBuilder::InsertionGuard guard(rewriter);
673  ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
674  }
675  }
676 
677  if (kind == TensorExp::Kind::kReduce)
678  env.endCustomReduc(); // exit custom
679 
680  if (kind == TensorExp::Kind::kSelect)
681  env.merger().setExprValue(e, v0); // Preserve value for later use.
682 
683  return ee;
684 }
685 
686 /// Hoists loop invariant tensor loads for which indices have been exhausted.
687 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
688  LoopId curr, bool isStart) {
 689  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
 690  return;
691  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
692  // Inspect tensor indices.
693  linalg::GenericOp op = env.op();
694  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
695  const auto map = op.getMatchingIndexingMap(&t);
696  const auto stt = getSparseTensorType(t.get());
697  const Level lvlRank = stt.getLvlRank();
698  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
699  bool isCurrentLoop = curr == 0; // for scalar tensors
700  for (Level l = 0; l < lvlRank; l++) {
701  const AffineExpr a = map.getResult(l);
702  if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
703  return; // still in play
704  }
705  // All exhausted at current level.
706  if (!isCurrentLoop)
707  return;
708  // Generate code for a scalarized reduction or invariant. Note that
709  // because custom reduction lhs may occur several times in the IR,
710  // we have a built-in safety for only initializing and wrapping-up
711  // the scalarized reduction once.
712  OpOperand *lhs = op.getDpsInitOperand(0);
713  if (lhs == &t) {
714  // Start or end a scalarized reduction.
715  if (isStart) {
716  if (env.isCustomReduc()) {
717  if (!env.isReduc())
718  env.startReduc(exp, env.getCustomRedId());
719  } else {
720  env.startReduc(exp, genTensorLoad(env, builder, exp));
721  }
722  if (env.hasSparseOutput())
 723  env.startValidLexInsert(
 724  constantI1(builder, env.op().getLoc(), false));
725  } else {
726  if (!env.isCustomReduc() || env.isReduc())
727  genTensorStore(env, builder, exp, env.endReduc());
728  if (env.hasSparseOutput())
729  env.endValidLexInsert();
730  }
731  } else {
732  // Start or end loop invariant hoisting of a tensor load.
733  if (isStart) {
734  env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
735  } else {
736  env.merger().clearExprValue(exp);
737  }
738  }
739  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
740  env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
741  env.exp(exp).kind != TensorExp::Kind::kSynZero) {
742  // Traverse into the binary operations. Note that we only hoist
743  // tensor loads, since subsequent MLIR/LLVM passes know how to
744  // deal with all other kinds of derived loop invariants.
745  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
746  env.startCustomReduc(exp); // enter custom
747  const ExprId e0 = env.exp(exp).children.e0;
748  const ExprId e1 = env.exp(exp).children.e1;
749  genInvariants(env, builder, e0, curr, isStart);
750  genInvariants(env, builder, e1, curr, isStart);
751  if (env.exp(exp).kind == TensorExp::Kind::kReduce)
752  env.endCustomReduc(); // exit custom
753  }
754 }
755 
756 /// Generates an expanded access pattern in innermost dimension.
757 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
758  bool isStart) {
759  linalg::GenericOp op = env.op();
760  OpOperand *lhs = op.getDpsInitOperand(0);
761  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
762  return; // not needed at current level
763  assert(!env.isReduc());
764  // Generate start or end of an expanded access pattern. Note that because
765  // an expansion does not rely on the ongoing contents of the sparse storage
766  // scheme, we can use the original tensor as incoming SSA value (which
767  // simplifies codegen a bit). If expansion on the actual contents is ever
768  // needed, we will need to use the SSA value in the insertion chain instead.
769  Value tensor = lhs->get();
770  Location loc = op.getLoc();
771  if (isStart) {
772  auto dynShape = {ShapedType::kDynamic};
773  Type etp = cast<ShapedType>(tensor.getType()).getElementType();
774  Type t1 = MemRefType::get(dynShape, etp);
775  Type t2 = MemRefType::get(dynShape, builder.getI1Type());
776  Type t3 = MemRefType::get(dynShape, builder.getIndexType());
777  Type t4 = builder.getIndexType();
778  auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
779  assert(r.getNumResults() == 4);
780  env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
781  r.getResult(3));
782  } else {
783  SmallVector<Value> indices;
784  for (LoopId i = 0; i < curr; i++)
785  indices.push_back(env.emitter().getLoopIV(i));
786  Value values = env.getExpandValues();
787  Value filled = env.getExpandFilled();
788  Value added = env.getExpandAdded();
789  Value count = env.getExpandCount();
790  Value chain = env.getInsertionChain();
791  Value compress = builder.create<CompressOp>(loc, values, filled, added,
792  count, chain, indices);
793  env.updateInsertionChain(compress);
794  env.endExpand();
795  }
796 }
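//
// For illustration: conceptually, the expansion above brackets the innermost
// loop nest with a sparse_tensor.expand op (yielding the values/filled/added
// buffers plus the running count consumed by genInsertionStore) and a matching
// sparse_tensor.compress op that merges the gathered entries back into the
// insertion chain of the output tensor.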
797 
 798 /// Returns true iff the loop should be parallelized. Any implicit loop in the
 799 /// Linalg operation that is marked "parallel" is a candidate. Whether it is
 800 /// actually converted to a parallel operation depends on the requested strategy.
801 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
802  // Reject parallelization of sparse output.
803  if (env.hasSparseOutput())
804  return false;
805  // Parallel loops on tensor expansion can cause data races.
806  if (env.isExpand())
807  return false;
808  // Inspect strategy.
 809  switch (env.options().parallelizationStrategy) {
 810  case SparseParallelizationStrategy::kNone:
 811  return false;
 812  case SparseParallelizationStrategy::kDenseOuterLoop:
 813  return isOuter && !isSparse;
 814  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
 815  return isOuter;
 816  case SparseParallelizationStrategy::kDenseAnyLoop:
 817  return !isSparse;
 818  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
 819  return true;
821  llvm_unreachable("unexpected parallelization strategy");
822 }
823 
 824 /// Whether or not the current loop being generated should be parallelized (if
 825 /// possible) according to the configuration.
826 static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
827  ArrayRef<TensorLevel> tidLvls) {
828  linalg::GenericOp op = env.op();
829  auto iteratorTypes = op.getIteratorTypesArray();
830  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
831  // Queries the LT based on the tensor and loop id, as requested by
832  // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
833  // should be consistent with the LT indexed by <TensorId, Level>.
834  const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
835  return lt.hasSparseSemantic();
836  });
837  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
838 }
839 
840 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
841 /// can either be a for loop or while loop depending on whether there is at most
842 /// one sparse level in the list.
 843 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 844  ArrayRef<TensorLevel> tidLvls,
845  unsigned numCases, bool tryParallel,
846  bool needsUniv) {
847  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
848  // Construct while-loop with a parameter for each index.
849  return env.emitter().enterCoIterationOverTensorsAtLvls(
850  builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
851  needsUniv);
852  });
853  assert(loop);
854  return loop;
855 }
856 
857 /// Generates a for-loop or a while-loop, depending on whether it implements
858 /// singleton iteration or co-iteration over the given conjunction.
859 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
860  unsigned numCases, bool needsUniv,
861  ArrayRef<TensorLevel> tidLvls) {
862  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
863  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
864  needsUniv);
865 }
866 
867 /// Generates the induction structure for a while-loop.
868 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
869  bool needsUniv) {
870  Location loc = env.op().getLoc();
871  // Finalize each else branch of all if statements.
872  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
873  while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
874  builder.getInsertionBlock()->getParentOp())) {
875  // Break on IfOp for slicing filtering.
876  if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
877  StringAttr::get(ifOp->getContext(), "slice"))
878  break;
879 
880  unsigned y = 0;
881  SmallVector<Value> yields;
882  if (env.isReduc()) {
883  yields.push_back(env.getReduc());
884  env.updateReduc(ifOp.getResult(y++));
885  if (env.isValidLexInsert()) {
886  yields.push_back(env.getValidLexInsert());
887  env.updateValidLexInsert(ifOp.getResult(y++));
888  }
889  }
890  if (env.isExpand()) {
891  yields.push_back(env.getExpandCount());
892  env.updateExpandCount(ifOp->getResult(y++));
893  }
894  if (env.getInsertionChain()) {
895  yields.push_back(env.getInsertionChain());
896  env.updateInsertionChain(ifOp->getResult(y++));
897  }
898  assert(y == yields.size());
899  builder.create<scf::YieldOp>(loc, yields);
900  builder.setInsertionPointAfter(ifOp);
901  }
902  }
903  // No need to set the insertion point here as LoopEmitter keeps track of the
904  // basic block where scf::Yield should be inserted.
905 }
906 
907 /// Generates a case region in the coiterate operation.
908 static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
909  unsigned caseIdx, LatPointId allCase,
910  LatPointId curCase,
911  MutableArrayRef<Value> reduc) {
912  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
913  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
914  const BitVector &curCaseBits = env.merger().lat(curCase).simple;
915 
916  /// Computes the subset of iterators that are valid in the current case being
917  /// generated.
918  I64BitSet caseBit(0);
919  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
920  if (curCaseBits.test(set))
921  caseBit.set(idx);
922 
923  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
924  caseIdx, reduc);
925 }
926 
927 /// Generates a single if-statement within a while-loop.
928 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
929  LatPointId p) {
930  Location loc = env.op().getLoc();
931  SmallVector<Type> types;
932  Value cond;
 933  env.merger().foreachTensorLoopId(
 934  p, /*simple=*/true,
935  [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
936  bool isIdxRed) {
937  if (isIdxRed) {
938  // Since there is no 1:1 mapping from loop to level (multiple loops
939  // are required to resolve one level with non-trivial index
940  // expression), we need to reconstruct the tensor level types if this
941  // loop requires index reduction condition.
942  assert(lvl.has_value() && isUndefLT(lt));
943  auto stt = getSparseTensorType(env.op().getInputs()[tid]);
944  lt = stt.getLvlType(*lvl);
945  }
946  assert(curr == env.merger().loop(b));
947  Value clause;
948  if (lt.hasSparseSemantic()) {
949  assert(lvl.has_value());
950  const Value crd = env.emitter().getCoord(tid, *lvl);
951  const Value lvar = env.getLoopVar(curr);
952  clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
953  crd, lvar);
954  } else {
955  assert(lt.hasDenseSemantic() || isUndefLT(lt));
956  clause = constantI1(builder, loc, true);
957  }
958  cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
959  });
960  if (env.isReduc()) {
961  types.push_back(env.getReduc().getType());
962  if (env.isValidLexInsert())
963  types.push_back(env.getValidLexInsert().getType());
964  }
965  if (env.isExpand())
966  types.push_back(builder.getIndexType());
967  if (env.getInsertionChain())
968  types.push_back(env.getInsertionChain().getType());
969  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
970  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
971  return ifOp;
972 }
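//
// For illustration: when co-iterating two sparse inputs a and b over loop
// variable iv, the condition built above is conceptually
//
//   cond = (a_crd == iv) && (b_crd == iv)
//
// i.e. one arith.cmpi "eq" clause per sparse tensor level, combined with
// arith.andi, while dense (or undefined) levels contribute a constant true.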
973 
974 /// Generates end of true branch of if-statement within a while-loop.
975 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
976  Value redInput, Value cntInput, Value insInput,
977  Value validIns) {
978  SmallVector<Value> operands;
979  if (env.isReduc()) {
980  operands.push_back(env.getReduc());
981  env.updateReduc(redInput);
982  if (env.isValidLexInsert()) {
983  // Any overlapping indices during a reduction creates a valid lex insert.
984  operands.push_back(constantI1(builder, env.op().getLoc(), true));
985  env.updateValidLexInsert(validIns);
986  }
987  }
988  if (env.isExpand()) {
989  operands.push_back(env.getExpandCount());
990  env.updateExpandCount(cntInput);
991  }
992  if (env.getInsertionChain()) {
993  operands.push_back(env.getInsertionChain());
994  env.updateInsertionChain(insInput);
995  }
996  if (!operands.empty())
997  builder.create<scf::YieldOp>(env.op().getLoc(), operands);
998  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // Sparsifier synthesis methods (loop sequence).
1003 //===----------------------------------------------------------------------===//
1004 
 1005 static bool getAllTidLvlsInLatPoints(
 1006  CodegenEnv &env, LatPointId li, LoopId curr,
1007  llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1008  const BitVector &simple = env.lat(li).simple;
1009  const TensorId outTid = env.merger().getOutTensorID();
1010  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1011 
1012  unsigned numloopCond = 0;
1013  bool hasNonUnique = false;
 1014  env.merger().foreachTensorLoopId(
 1015  li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1016  LevelType lt, bool isIdxReduc) {
1017  if (simple[b]) {
1018  if (isIdxReduc) {
1019  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1020  numloopCond++;
1021  return;
1022  }
1023  if (isUndefLT(lt)) {
 1024  // For an undefined lt in the lattices, we probably mean to
1025  // generate a dense loop according to the synthetic tensor (for
1026  // invariants and sparse output tensor).
1027  if (env.merger().getSynTensorID() == tid) {
1028  // Coiterating with an invariant
1029  // e.g., out = prod(in[i][j] op invariant);
1030  // or a broadcast
1031  // e.g., out[i][j] = in[i] (j is undef for input)
1032  //
1033  // The level of the synthetic tensor is the current loop depth;
 1034  // the rank of the synthetic tensor equals the number of loops.
1035  assert(curr == env.getCurrentDepth());
1036  lvl = curr;
1037  } else if (!lvl) {
1038  // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1039  return;
1040  }
1041  }
1042  hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1043  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1044  numloopCond++;
1045  } else if (lt.hasDenseSemantic() || isIdxReduc) {
1046  callback(env.makeTensorLevel(tid, *lvl), nullptr);
1047  } else {
1048  assert(isUndefLT(lt));
1049  linalg::GenericOp op = env.op();
1050  if (tid >= op.getNumDpsInputs())
1051  // We only handle affine expression on input tensors (for now).
1052  return;
1053  OpOperand *operand = &op->getOpOperand(tid);
1054  const auto stt = getSparseTensorType(operand->get());
 1055  // Non-annotated dense tensors require no special handling.
1056  if (!stt.hasEncoding())
1057  return;
1058 
1059  ArrayRef<AffineExpr> affines =
1060  op.getMatchingIndexingMap(operand).getResults();
1061  const Level lvlRank = stt.getLvlRank();
1062  assert(affines.size() == static_cast<size_t>(lvlRank));
1063  for (Level l = 0; l < lvlRank; l++) {
1064  AffineExpr exp = affines[l];
1065  // Skip simple affine expression and non-dense levels (which
1066  // have their own filter loop).
1067  LevelType lt = stt.getLvlType(l);
1068  if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1069  continue;
1070 
 1071  // Constant affine expressions are handled in genLoop.
1072  if (!isa<AffineConstantExpr>(exp)) {
1073  bool isCurrentLoop = false;
1074  assert(curr == env.getCurrentDepth());
1075  if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1076  isCurrentLoop) {
 1077  // If the compound affine expression is invariant and we are right at
 1078  // the level, we need to generate the address according to the affine
 1079  // expression. This is also the best place to do it, to avoid putting
 1080  // it inside inner loops.
1081  callback(env.makeTensorLevel(tid, l), exp);
1082  }
1083  }
1084  }
1085  }
1086  });
1087 
1088  if (isDenseLT(env.lt(outTid, curr))) {
1089  auto stt = getSparseTensorType(env.op().getOutputs().front());
1090  // Note that we generate dense indices of the output tensor unconditionally,
1091  // since they may not appear in the lattice, but may be needed for
1092  // linearized env.
1093  // TODO: we should avoid introducing corner cases for all-dense sparse
1094  // tensors.
1095  if (stt.hasEncoding() && stt.isAllDense())
1096  callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1097  }
1098 
1099  if (numloopCond == 0) {
 1100  // Corner case where the loop bound is defined by an *unused* operand; in
 1101  // this case, we just generate a dense "fake" loop by iterating over the
1102  // synthetic tensor.
1103  callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1104  numloopCond++;
1105  }
 1106  // If we only need one loop condition and the condition is not imposed on a
 1107  // non-unique level, the loop can be generated by a for loop.
 1108  // Or, if we are generating sparse-iterator-based loops, we always generate
 1109  // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1110  return numloopCond == 1 &&
 1111  (!hasNonUnique || env.options().sparseEmitStrategy ==
 1112  SparseEmitStrategy::kSparseIterator);
1113 }
1114 
1115 /// Starts a loop sequence at given level. Returns true if
1116 /// the universal loop index must be maintained at this level.
1117 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1118  LoopId curr, LatSetId lts) {
1119  assert(!env.getLoopVar(curr));
1120  // Emit invariants at this loop sequence level.
1121  genInvariants(env, builder, exp, curr, /*isStart=*/true);
1122  // Emit access pattern expansion for sparse tensor output.
1123  genExpand(env, builder, curr, /*isStart=*/true);
1124  // Emit further initialization at this loop sequence level.
1125  const LatPointId l0 = env.set(lts)[0];
1126 
1127  SmallVector<TensorLevel> tidLvls;
1128  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
 1129  // TODO: remove this! The same tensor level might be added multiple
 1130  // times due to the special handling for the all-dense "sparse" output tensor
1131  // (see L1038).
1132  if (llvm::find(tidLvls, tl) != tidLvls.end())
1133  return;
1134  tidLvls.emplace_back(tl);
1135  });
1136 
1137  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1138 
1139  // Maintain the universal index only if it is actually
1140  // consumed by a subsequent lattice point.
1141  for (const LatPointId li : env.set(lts).drop_front())
1142  if (!env.merger().hasAnySparse(env.lat(li).simple))
1143  return true;
1144 
1145  return false;
1146 }
1147 
1148 // Generates dense affine address for encoding.
 1149 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
 1150  OpBuilder &builder, TensorId tid,
1151  Level startLvl) {
1152  // TODO: Handle affine expression on output tensor.
1153  linalg::GenericOp op = env.op();
1154  assert(tid < op.getNumDpsInputs());
1155  OpOperand *input = op.getDpsInputOperands()[tid];
1156  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1157  const auto enc = getSparseTensorEncoding(input->get().getType());
1158  if (enc) {
1159  const Location loc = op.getLoc();
1160  const TensorId tid = env.makeTensorId(input->getOperandNumber());
1161  const Level lvlRank = enc.getLvlRank();
1162  assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1163  for (Level l = startLvl; l < lvlRank; l++) {
1164  AffineExpr lvlExpr = lvlExprs[l];
1165  if (enc.getLvlType(l).hasDenseSemantic() &&
1166  isa<AffineConstantExpr>(lvlExpr))
 1167  env.emitter().locateLvlAtAffineAddress(
 1168  builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1169  else
1170  return; // break on first non-dense non-constant level
1171  }
1172  }
1173 }
1174 
 1175 // We can generate addresses for constant affine expressions before any loops,
 1176 // starting from the first level, as they do not depend on anything.
1177 // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1178 // levels can be determined before loops.
 1179 static void genInitConstantDenseAddress(CodegenEnv &env,
 1180  RewriterBase &rewriter) {
1181  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1182  genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1183 }
1184 
1185 /// Returns true if the lattice bit can be iterated by a for loop.
 1186 static bool translateBitsToTidLvlPairs(
 1187  CodegenEnv &env, LatPointId li, LoopId curr,
 1188  SmallVectorImpl<TensorLevel> &tidLvls,
 1189  SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1190  return getAllTidLvlsInLatPoints(env, li, curr,
1191  [&](TensorLevel tl, AffineExpr exp) {
1192  if (exp)
1193  affineTidLvls.emplace_back(tl, exp);
1194  else
1195  tidLvls.emplace_back(tl);
1196  });
1197 }
1198 
1199 /// Starts a single loop in current sequence.
1200 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1201  OpBuilder &builder, LoopId curr,
1202  LatPointId li, unsigned numCases,
1203  bool needsUniv) {
 1204  // TODO: numCases is only used when generating iterator-based loops. Clean
 1205  // up after the migration is complete.
1206  // The set of tensors + lvls to generate loops on
1207  SmallVector<TensorLevel> tidLvls;
1208 
 1209  // The set of dense tensors with non-trivial affine expressions that just
 1210  // became invariant; their addresses are generated at the current level.
 1211  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
 1212  bool isSingleCond =
1213  translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1214 
1215  // Emit the for/while-loop control.
1216  Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1217  Location loc = env.op().getLoc();
1218  for (auto [tidLvl, exp] : affineTidLvls) {
1219  env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1220  }
1221 
1222  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1223  // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent
 1224  // on constant affine expressions may now be determined.
1225  auto allTidLvls =
1226  llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1227  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1228  if (tid != env.merger().getOutTensorID() &&
1229  tid != env.merger().getSynTensorID())
1230  genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1231  }
1232 
1233  return std::make_pair(loop, isSingleCond);
1234 }
1235 
 1236 /// Ends a single loop in the current sequence. Returns the new value of needsUniv.
1237 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1238  LatPointId li, bool needsUniv, bool isSingleCond) {
1239  // Either a for-loop or a while-loop that iterates over a slice.
1240  if (isSingleCond) {
1241  // Any iteration creates a valid lex insert.
1242  if (env.isReduc() && env.isValidLexInsert())
1243  env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1244  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1245  // End a while-loop.
1246  finalizeWhileOp(env, rewriter, needsUniv);
1247  } else {
1248  needsUniv = false;
1249  }
1250  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1251  env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1252  return std::nullopt;
1253  });
1254  return needsUniv;
1255 }
1256 
1257 /// Ends a loop sequence at given level.
1258 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1259  unsigned at) {
1260  assert(!env.getLoopVar(at));
1261  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1262  // Unmark bookkeeping of invariants and loop index.
1263  genInvariants(env, builder, exp, at, /*isStart=*/false);
1264  // Finalize access pattern expansion for sparse tensor output.
1265  genExpand(env, builder, at, /*isStart=*/false);
1266 }
1267 
1268 /// Recursively generates code while computing iteration lattices in order
1269 /// to manage the complexity of implementing co-iteration over unions
1270 /// and intersections of sparse iterations spaces.
1271 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1272  LoopId curr) {
1273  assert(curr == env.getCurrentDepth());
1274 
1275  // At each leaf, assign remaining tensor (sub)expression to output tensor.
1276  if (curr == env.getLoopNum()) {
1277  Value rhs = genExp(env, rewriter, exp);
1278  genTensorStore(env, rewriter, exp, rhs);
1279  return;
1280  }
1281 
1282  // Construct iteration lattices for current loop index.
1283  const LatSetId lts =
1284  env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1285 
1286  // Start a loop sequence.
1287  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1288 
 1289  // When using sparse-iterator-based loops, we only need one loop, as
 1290  // opposed to a loop sequence, to cover all the iteration spaces.
1291  const unsigned lsize = env.set(lts).size();
1292  if (env.generatingSparseIterator()) {
1293  // Get the largest lattice point and start a loop.
1294  const LatPointId li = env.set(lts)[0];
1295  auto [loop, isSingleCond] =
1296  startLoop(env, rewriter, curr, li, lsize, needsUniv);
1297  assert(isSingleCond == llvm::isa<IterateOp>(loop));
1298  // We cannot change this to `for (const LatPointId li : env.set(lts))`
1299  // because the loop body causes data-movement which invalidates
1300  // the iterator.
1301  for (unsigned j = 0; j < lsize; j++) {
1302  const LatPointId lj = env.set(lts)[j];
1303  const ExprId ej = env.lat(lj).exp;
1304  // Recurse into body of each branch.
1305  if (!isSingleCond) {
1306  env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307  genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1308  genStmt(env, rewriter, ej, curr + 1);
1309  // TODO: handle yield values.
1310  assert(reduc.empty() && "Not Implemented");
1311  rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1312  return std::nullopt;
1313  });
1314  // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315  } else {
1316  genStmt(env, rewriter, ej, curr + 1);
1317  }
1318  }
1319  // End a loop.
1320  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1321  } else {
1322  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323  for (unsigned i = 0; i < lsize; i++) {
1324  const LatPointId li = env.set(lts)[i];
1325  // Start a loop.
1326  auto [loop, isSingleCond] =
1327  startLoop(env, rewriter, curr, li, lsize, needsUniv);
1328 
1329  // Visit all lattices points with Li >= Lj to generate the
1330  // loop-body, possibly with if statements for coiteration.
1331  Value redInput = env.getReduc();
1332  Value cntInput = env.getExpandCount();
1333  Value insInput = env.getInsertionChain();
1334  Value validIns = env.getValidLexInsert();
1335  // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1336  // because the loop body causes data-movement which invalidates the
1337  // iterator.
1338  for (unsigned j = 0; j < lsize; j++) {
1339  const LatPointId lj = env.set(lts)[j];
1340  const ExprId ej = env.lat(lj).exp;
1341  if (li == lj || env.merger().latGT(li, lj)) {
1342  // Recurse into body of each branch.
1343  if (!isSingleCond) {
1344  scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1345  genStmt(env, rewriter, ej, curr + 1);
1346  endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347  } else {
1348  genStmt(env, rewriter, ej, curr + 1);
1349  }
1350  }
1351  }
1352 
1353  // End a loop.
1354  needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1355  }
1356  }
1357 
1358  // End a loop sequence.
1359  endLoopSeq(env, rewriter, exp, curr);
1360  assert(curr == env.getCurrentDepth());
1361 }
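//
// For illustration: for a kernel x(i) = a(i) + b(i) with both inputs sparse,
// the optimized lattice for loop i has three points (a /\ b, a, b), so the
// recursion above conceptually emits
//
//   while (a and b not exhausted) {
//     if (a_crd == i && b_crd == i)  x(i) = a(i) + b(i)
//     else if (a_crd == i)           x(i) = a(i)
//     else                           x(i) = b(i)
//   }
//   for (remaining entries of a)     x(i) = a(i)
//   for (remaining entries of b)     x(i) = b(i)
//
// where genIf/endIf produce the conditionals and endLoop/endLoopSeq close the
// while-loop and the trailing for-loops of the sequence.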
1362 
1363 /// Converts the result computed by the sparse kernel into the required form.
1364 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1365  linalg::GenericOp op = env.op();
1366  OpOperand *lhs = op.getDpsInitOperand(0);
1367  Value tensor = lhs->get();
1368  Type resType = tensor.getType();
1369  if (getSparseTensorEncoding(resType)) {
1370  // The sparse tensor rematerializes from the original sparse tensor's
1371  // underlying sparse storage format. For an insertion chain, the
1372  // tensor materializes from the chain with 'hasInserts' enabled.
1373  bool hasInserts = false;
1374  if (Value chain = env.getInsertionChain()) {
1375  hasInserts = true;
1376  tensor = chain;
1377  }
1378  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1379  } else {
 1380  // To rematerialize a non-annotated tensor, simply load it
1381  // from the bufferized value.
1382  Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1383  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1384  }
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // Sparsifier rewriting methods.
1389 //===----------------------------------------------------------------------===//
1390 
1391 namespace {
1392 
 1393 /// Sparse rewriting rule for a generic Linalg operation.
1394 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1395 public:
1396  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1397  : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1398 
1399  LogicalResult matchAndRewrite(linalg::GenericOp op,
1400  PatternRewriter &rewriter) const override {
1401  // Only accept single output operations with pure tensor semantics.
1402  if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1403  return failure();
1404 
1405  // Only accept trivial affine indices.
 1406  if (hasNonTrivialAffineOnSparseOut(op))
 1407  return failure();
1408 
1409  // Only accept scheduled loops.
1410  if (!op->hasAttr("sorted")) {
1411  return rewriter.notifyMatchFailure(
1412  op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1413  "before sparsification.");
1414  }
1415 
 1416  // Must have been demapped as well if the generic op is sorted.
 1417  assert(!hasAnyNonIdentityOperandsOrResults(op));
 1418 
1419  // Sets up a code generation environment.
1420  const unsigned numTensors = op->getNumOperands();
1421  const unsigned numLoops = op.getNumLoops();
1422  bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
 1423  // If we have an indexing map like (d0) -> (0, d0), there might be more
 1424  // levels than loops because of the constant index, which means we cannot
 1425  // use numLoops as the upper bound for the ranks of all tensors.
 1426  // TODO: Constant indices are currently not supported on sparse tensors,
 1427  // but are allowed in non-annotated dense tensors. Supporting them would
 1428  // also be required for rank-reducing sparse tensor slices.
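// As a concrete illustration: with an indexing map (d0) -> (0, d0), the
// kernel has numLoops == 1, yet the tensor it indexes has two levels, so its
// level-rank (2) exceeds the loop count; maxLvlRank, computed below, therefore
// serves as the rank bound instead of numLoops.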
1429  Level maxLvlRank = 0;
1430  for (auto operand : op.getOperands()) {
1431  if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1432  maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1433  }
1434  }
1435 
1436  // Detects sparse annotations and translates the per-level sparsity
1437  // information for all tensors to loop indices in the kernel.
1438  CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1439  if (!findSparseAnnotations(env, needIdxRed))
1440  return failure();
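// For instance (illustrative; the #CSR alias is a placeholder), a matrix
// operand annotated with
//
//   #CSR = #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed)
//   }>
//
// contributes a dense level-type to the loop driving d0 and a compressed
// level-type to the loop driving d1, which is exactly the per-level
// information the merger consumes when building the iteration lattices.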
1441 
1442  // Only standard reduction operations (add, sub, or, xor) that can be
1443  // sparsified by merely reducing the stored values are admissible. More
1444  // elaborate reduction operations (such as mul, and, min, max) would need
1445  // to know whether implicit zeros occur as well. They can still be
1446  // implemented with a custom reduction operation, accepted here as well.
1447  if (op.getNumReductionLoops() > 0) {
1448  Operation *yield = op.getRegion().front().getTerminator();
1449  assert(isa<linalg::YieldOp>(yield));
1450  Operation *redop = yield->getOperand(0).getDefiningOp();
1451  if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1452  !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1453  !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1454  !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1455  !isa<ReduceOp>(redop)) {
1456  return failure();
1457  }
1458  }
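// For example, a sum reduction x += a(i) may simply skip the implicit zeros
// of `a`, since they contribute nothing to the sum. A product reduction
// x *= a(i), by contrast, would be driven to zero by any skipped implicit
// zero, so it is only admissible through the custom sparse_tensor.reduce
// (ReduceOp) form accepted above, which carries an explicit identity value.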
1459 
 1460  // Constructs the tensor expression tree from `op`; returns failure if the
 1461  // tree cannot be built or the tensor expression is inadmissible.
1462  if (failed(env.initTensorExp()))
1463  return failure();
1464 
1465  // Recursively generates code if admissible.
1466  env.startEmit(options.sparseEmitStrategy);
1467  genBuffers(env, rewriter);
 1468  // TODO: Constant affine expressions should be handled differently when
 1469  // using slice-based codegen; it does not matter for now because constant
 1470  // expressions are already rejected at an earlier stage.
1471  genInitConstantDenseAddress(env, rewriter);
1472  genStmt(env, rewriter, env.getExprId(), 0);
1473  genResult(env, rewriter);
1474  return success();
1475  }
1476 
1477 private:
 1478  /// Options to control sparse code generation.
 1479  SparsificationOptions options;
 1480 };
1481 
1482 } // namespace
1483 
1484 /// Populates the given patterns list with rewriting rules required for
1485 /// the sparsification of linear algebra operations.
 1486 void mlir::populateSparsificationPatterns(
 1487  RewritePatternSet &patterns, const SparsificationOptions &options) {
1488  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1489 }
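A minimal sketch of how these patterns are typically driven from a pass (not part of this file; the helper name runSparsification is hypothetical, and the greedy driver entry point has been renamed across MLIR versions):

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: collect the sparsification patterns and apply them
// greedily to the given root operation.
static void runSparsification(mlir::Operation *root,
                              const mlir::SparsificationOptions &options) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::populateSparsificationPatterns(patterns, options);
  // Ignore the convergence result in this sketch.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}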