SparseTensorRewriting.cpp
1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodegenUtils.h"
14 #include "LoopEmitter.h"
15 
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/Support/LLVM.h"
30 
31 using namespace mlir;
32 using namespace mlir::bufferization;
33 using namespace mlir::linalg;
34 using namespace mlir::sparse_tensor;
35 
36 //===---------------------------------------------------------------------===//
37 // Helper methods for the actual rewriting rules.
38 //===---------------------------------------------------------------------===//
39 
40 // Helper method to match any typed zero.
41 static bool isZeroValue(Value val) {
42  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
43 }
44 
45 // Helper to detect a sparse tensor type operand.
46 static bool isSparseTensor(Value v) {
47  auto enc = getSparseTensorEncoding(v.getType());
48  return enc && !llvm::all_of(enc.getLvlTypes(),
49  [](auto lt) { return lt == LevelType::Dense; });
50 }
51 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
52 
53 // Helper method to find zero/uninitialized tensor materialization.
54 static bool isMaterializing(OpOperand *op, bool isZero) {
55  Value val = op->get();
56  // Check allocation, with zero alloc when required.
57  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
58  Value copy = alloc.getCopy();
59  if (isZero)
60  return copy && isZeroValue(copy);
61  return !copy;
62  }
63  // Check for empty tensor materialization.
64  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
65  return !isZero;
66  // Last resort for zero alloc: the whole value is zero.
67  return isZero && isZeroValue(val);
68 }
69 
70 // Helper to detect sampling operation.
71 static bool isSampling(GenericOp op) {
72  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
73  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
74  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
75  // Both scalar input arguments used exactly once.
76  Value s1 = op.getBlock()->getArgument(0);
77  Value s2 = op.getBlock()->getArgument(1);
78  return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
79  (def->getOperand(1) == s1 && def->getOperand(0) == s2);
80  }
81  }
82  return false;
83 }
84 
85 // Helper to detect chain of multiplications that do not involve x.
86 static bool isMulChain(Value val, Value x) {
87  if (auto arg = dyn_cast<BlockArgument>(val))
88  return arg != x;
89  if (auto *def = val.getDefiningOp()) {
90  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
91  return isMulChain(def->getOperand(0), x) &&
92  isMulChain(def->getOperand(1), x);
93  }
94  return false;
95 }
96 
97 // Helper to detect x = x + <multiplications>.
98 static bool isSumOfMul(GenericOp op) {
99  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
100  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
101  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
102  Value x = op.getBlock()->getArguments().back();
103  return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
104  (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
105  }
106  }
107  return false;
108 }
109 
110 // Helper to detect direct yield of a zero value.
111 static bool isZeroYield(GenericOp op) {
112  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
113  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
114  if (arg.getOwner()->getParentOp() == op) {
115  return isZeroValue(op->getOperand(arg.getArgNumber()));
116  }
117  }
118  return isZeroValue(yieldOp.getOperand(0));
119 }
120 
121 /// Populates given sizes array from type (for static sizes) and from
122 /// the tensor (for dynamic sizes).
123 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
124  Location loc, ShapedType stp, Value tensor) {
125  for (const auto &d : enumerate(stp.getShape())) {
126  Value dim;
127  if (d.value() == ShapedType::kDynamic)
128  dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
129  else
130  dim = constantIndex(builder, loc, d.value());
131  sizes.push_back(dim);
132  }
133 }
134 
135 static RankedTensorType getBufferType(const SparseTensorType &stt,
136  bool needTmpCOO) {
137  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
138  : stt.getRankedTensorType();
139 }
140 
141 /// Collects the dynamic dimension sizes for `tp` with the assumption that
142 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
143 /// sizes to dynSizes.
144 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
145  SmallVectorImpl<Value> &dynSizes) {
146  for (const auto &d : enumerate(tp.getShape())) {
147  if (d.value() == ShapedType::kDynamic)
148  dynSizes.push_back(sizes[d.index()]);
149  }
150 }
151 
152 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
153  RewriterBase &rewriter,
154  SparseElementsAttr attr) {
155  auto loc = op.getLoc();
156  SmallVector<Value> reduc = op.getInitArgs();
157 
158  // Foreach on constant.
159  foreachInSparseConstant(
160  rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
161  [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
162  SmallVector<Value> args;
163  args.append(cvs.begin(), cvs.end());
164  args.push_back(v);
165  args.append(reduc);
166  // Clones the foreach op to get a copy of the loop body.
167  auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
168  assert(args.size() == cloned.getBody()->getNumArguments());
169  Operation *yield = cloned.getBody()->getTerminator();
170  rewriter.inlineBlockBefore(cloned.getBody(), op, args);
171  // clean up
172  rewriter.eraseOp(cloned);
173  reduc = yield->getOperands();
174  rewriter.eraseOp(yield);
175  });
176 
177  rewriter.replaceOp(op, reduc);
178  return success();
179 }
180 
181 /// Populates the given sizes array for concatenation from types (for static
182 /// sizes) and from the source tensors (for dynamic sizes).
183 static void concatSizesFromInputs(OpBuilder &builder,
184  SmallVectorImpl<Value> &sizes, Location loc,
185  ShapedType dstTp, ValueRange srcs,
186  unsigned dim) {
187  auto dstShape = dstTp.getShape();
188  sizesFromSrc(builder, sizes, loc, srcs[0]);
189 
190  // Sum up on the `dim` if the dimension is dynamic.
191  if (dstShape[dim] != ShapedType::kDynamic) {
192  // Faithfully take the static size.
193  sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
194  } else {
195  // Else, compute the shape dynamically.
196  for (const auto &src : srcs.drop_front()) {
197  Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
198  // Sum up all the sizes.
199  sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
200  }
201  }
202 }
203 
204 //===---------------------------------------------------------------------===//
205 // The actual sparse tensor rewriting rules.
206 //===---------------------------------------------------------------------===//
207 
208 namespace {
209 
210 /// Rewriting rule that converts direct yield of zero with initial allocation.
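/// For example (editor's illustration, not in the original comment): a
/// linalg.generic whose region does nothing but yield a zero value into a
/// freshly materialized output (bufferization.alloc_tensor or tensor.empty)
/// is replaced below by that materialization itself when the result is
/// sparse, or by a constant zero tensor when the result is dense with a
/// static shape.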
211 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
212 public:
213  using OpRewritePattern<GenericOp>::OpRewritePattern;
214 
215  LogicalResult matchAndRewrite(GenericOp op,
216  PatternRewriter &rewriter) const override {
217  if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
218  !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
219  !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
220  return failure();
221  auto outputType = getRankedTensorType(op.getResult(0));
222  // Yielding zero on newly materialized sparse tensor can be
223  // optimized directly (regardless of dynamic or static size).
224  if (getSparseTensorEncoding(outputType)) {
225  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
226  return success();
227  }
228  // Use static zero value directly instead of materialization.
229  if (!outputType.hasStaticShape())
230  return failure();
231  Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
232  rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
233  rewriter.eraseOp(def);
234  return success();
235  }
236 };
237 
238 /// Rewriting rule that converts two kernels:
239 ///
240 /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
241 /// X(i,j) = S(i,j) * T(i,j)
242 ///
243 /// into a single kernel, using distributive law:
244 ///
245 /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
246 ///
247 /// This kind of fusion (merging two ops into one but using arithmetic
248 /// equalities that may not hold for floating-point computations) would
249 /// be undesirable in the dense case, since we distribute the multiplication
250 /// into the reduction loop. However, for sparse sampling tensor S, such
251 /// a fusion may actually reduce the asymptotic complexity of the kernel,
252 /// since intermediate results may be nullified.
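/// As a concrete illustration (editor's note, not in the original comment),
/// this is the classic sampled dense-dense matrix multiplication (SDDMM)
/// pattern: with a sparse sampling matrix S and dense A and B,
///   T(i,j) = SUM(k, A(i,k) * B(k,j))
///   X(i,j) = S(i,j) * T(i,j)
/// fuses into
///   X(i,j) = SUM(k, S(i,j) * A(i,k) * B(k,j))
/// so the dot products are only evaluated at coordinates stored in S.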
253 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
254 public:
255  using OpRewritePattern<GenericOp>::OpRewritePattern;
256 
257  LogicalResult matchAndRewrite(GenericOp op,
258  PatternRewriter &rewriter) const override {
259  // Check consumer.
260  if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 2 ||
261  op.getNumResults() != 1 ||
262  op.getNumParallelLoops() != op.getNumLoops() ||
263  !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
264  !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
265  !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
266  return failure();
267  // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
268  // operand can be sparse or dense, since the point of this rewriting rule
269  // is detecting a situation in which *more* sparsity is introduced into
270  // a computation, be it already sparse or still dense.
271  unsigned other = 0;
272  if (isSparseTensor(op.getDpsInputOperand(0)))
273  other = 1;
274  else if (!isSparseTensor(op.getDpsInputOperand(1)))
275  return failure();
276  // Check producer.
277  auto prod = dyn_cast_or_null<GenericOp>(
278  op.getDpsInputOperand(other)->get().getDefiningOp());
279  if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
280  !prod.getResult(0).hasOneUse())
281  return failure();
282  // Sampling consumer and sum of multiplication chain producer.
283  if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
284  !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
285  !isSampling(op) || !isSumOfMul(prod))
286  return failure();
287  // Modify operand structure of producer and consumer.
288  Location loc = prod.getLoc();
289  SmallVector<Value> inputOps = prod.getInputs();
290  SmallVector<Value> outputOps = op.getOutputs();
291  SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
292  inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
293  fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
294  // Fuse producer and consumer into a new generic op.
295  auto fusedOp = rewriter.create<GenericOp>(
296  loc, op.getResult(0).getType(), inputOps, outputOps,
297  rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
298  /*doc=*/nullptr, /*library_call=*/nullptr);
299  Block &prodBlock = prod.getRegion().front();
300  Block &consBlock = op.getRegion().front();
301  IRMapping mapper;
302  Block *fusedBlock = new Block();
303  fusedOp.getRegion().push_back(fusedBlock);
304  unsigned num = prodBlock.getNumArguments();
305  for (unsigned i = 0; i < num - 1; i++)
306  addArg(mapper, fusedBlock, prodBlock.getArgument(i));
307  addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
308  addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
309  // Clone bodies of the producer and consumer in new evaluation order.
310  auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
311  auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
312  rewriter.setInsertionPointToStart(fusedBlock);
313  Value last;
314  for (auto &op : prodBlock.without_terminator())
315  if (&op != acc) {
316  last = op.getResult(0);
317  rewriter.clone(op, mapper);
318  }
319  mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
320  mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
321  last = rewriter.clone(*acc, mapper)->getResult(0);
322  rewriter.create<linalg::YieldOp>(loc, last);
323  // Force initial value on merged allocation for dense outputs.
324  // TODO: deal with non alloc tensor here one day
325  if (!getSparseTensorEncoding(op.getResult(0).getType())) {
326  Value init = prod.getDpsInitOperand(0)
327  ->get()
328  .getDefiningOp<AllocTensorOp>()
329  .getCopy();
330  AllocTensorOp a =
331  op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
332  rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
333  }
334  // Replace consumer with fused operation. Old producer
335  // and consumer ops will be removed by DCE.
336  rewriter.replaceOp(op, fusedOp->getResults());
337  return success();
338  }
339 
340 private:
341  // Helper to add argument and record the mapping.
342  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
343  mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
344  }
345 };
346 
347 // Fuse a tensor cast into producing operation. Note that a tensor.cast
348 // should really not be used to convert between sparse encodings. Since
349 // the pattern currently appears as a result of some prior rewriting
350 // we make an attempt to repair very obvious cases.
351 // TODO: audit the pure tensor dialect rewriting rules
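// For example (editor's sketch, assuming some sparse encoding #SV): a cast
// that only changes the encoding, such as
//   %1 = tensor.cast %0 : tensor<?xf32, #SV> to tensor<?xf32>
// is repaired below into the supported form
//   %1 = sparse_tensor.convert %0 : tensor<?xf32, #SV> to tensor<?xf32>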
352 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
353 public:
354  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
355 
356  LogicalResult matchAndRewrite(tensor::CastOp op,
357  PatternRewriter &rewriter) const override {
358  Type srcType = op.getSource().getType();
359  Type dstType = op.getDest().getType();
360  // A nop cast simply folds away.
361  if (srcType == dstType) {
362  rewriter.replaceOp(op, op->getResults());
363  return success();
364  }
365  // See if a sparsity changing cast can be fused into producer.
366  if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
367  if (Operation *def = op.getSource().getDefiningOp()) {
368  if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
369  rewriter.updateRootInPlace(def, [&]() {
370  def->getResult(0).setType(op->getResultTypes()[0]);
371  });
372  rewriter.replaceOp(op, def->getResult(0));
373  return success();
374  }
375  }
376  }
377  // Repair tensor casts with at least one sparse operand into the
378  // properly supported sparse_tensor.convert.
379  if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
380  rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
381  return success();
382  }
383  // Fail otherwise.
384  return failure();
385  }
386 };
387 
388 /// Rewrites a sequence of operations for sparse tensor selections into
389 /// semi-ring operations such that they can be compiled correctly by the
390 /// sparsifier. E.g., transforming the following sequence
391 ///
392 /// %sel = arith.select %cond, %sp1, %sp2
393 ///
394 /// to
395 ///
396 /// %sel = binary %sp1, %sp2:
397 /// both (%l, %r) {yield select %cond, %l, %r}
398 /// left (%l) {yield select %cond, %l, 0}
399 /// right (%r) {yield select %cond, 0, %r}
400 ///
401 /// TODO: We require the tensor used for extracting conditions to be dense
402 /// in order to sparsify the code. To support a sparse condition tensor, we
403 /// need a ternary operation.
404 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
405 public:
406  using OpRewritePattern<GenericOp>::OpRewritePattern;
407  LogicalResult matchAndRewrite(GenericOp op,
408  PatternRewriter &rewriter) const override {
409  // Rejects non sparse kernels.
410  if (!op.hasTensorSemantics() || !hasAnySparseOperand(op))
411  return failure();
412 
413  Location loc = op.getLoc();
414  SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
415  for (Operation &inst : *op.getBody()) {
416  // Matches pattern.
417  auto matched = isRewritablePattern(op, &inst);
418  if (!matched.has_value())
419  continue;
420 
421  rewriter.setInsertionPoint(&inst);
422  auto [c, t, f] = matched.value();
423  assert(t.getType() == f.getType());
424  auto selTp = t.getType();
425  auto c0 = constantZero(rewriter, loc, selTp);
426  auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
427  // Initializes all the blocks.
428  rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
429  {t.getLoc(), f.getLoc()});
430  rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
431  rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
432 
433  for (auto *r : binOp.getRegions()) {
434  Block *b = &r->front();
435  rewriter.setInsertionPointToStart(b);
436 
437  IRMapping irMap;
438  // Clones the cmp operations into the region to make the binary op
439  // admissible.
440  Value newC = c;
441  if (auto *def = c.getDefiningOp())
442  newC = rewriter.clone(*def, irMap)->getResult(0);
443 
444  irMap.map(c, newC);
445  if (r == &binOp.getLeftRegion()) {
446  irMap.map(t, b->getArgument(0));
447  irMap.map(f, c0);
448  } else if (r == &binOp.getRightRegion()) {
449  irMap.map(t, c0);
450  irMap.map(f, b->getArgument(0));
451  } else {
452  irMap.map(t, b->getArgument(0));
453  irMap.map(f, b->getArgument(1));
454  }
455  auto y = rewriter.clone(inst, irMap)->getResult(0);
456  rewriter.create<sparse_tensor::YieldOp>(loc, y);
457  }
458 
459  // We successfully rewrote an operation. We cannot do the replacement here
460  // because it would invalidate the iterator that the current loop uses to
461  // traverse the instructions.
462  semiRings.emplace_back(&inst, binOp);
463  }
464 
465  // Finalizes the replacement.
466  for (auto [sel, semi] : semiRings)
467  rewriter.replaceOp(sel, semi->getResults());
468 
469  return success(!semiRings.empty());
470  }
471 
472 private:
473  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
474  isRewritablePattern(GenericOp op, Operation *v) {
475  auto sel = dyn_cast<arith::SelectOp>(v);
476  if (!sel)
477  return std::nullopt;
478 
479  auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
480  auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
481  // TODO: For simplicity, we only handle cases where both true/false values
482  // are directly loaded from the input tensor. We can probably admit more cases
483  // in theory.
484  if (!tVal || !fVal)
485  return std::nullopt;
486 
487  // Helper lambda to determine whether the value is loaded from a dense input
488  // or is a loop invariant.
489  auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
490  if (auto bArg = v.dyn_cast<BlockArgument>();
491  bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
492  return true;
493  // If the value is defined outside the loop, it is a loop invariant.
494  return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
495  };
496 
497  // If the condition value is loaded directly from a dense tensor or is a
498  // loop invariant, we can sparsify the kernel.
499  auto cond = sel.getCondition();
500  if (isValFromDenseInputOrInvariant(cond))
501  return std::make_tuple(cond, tVal, fVal);
502 
503  Value cmpL, cmpR;
504  if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
505  matchers::m_Any(&cmpR))) ||
506  matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
507  matchers::m_Any(&cmpR)))) {
508  // TODO: we can do it recursively to check whether all the leaf values are
509  // loaded from dense tensors or are loop invariants.
510  if (isValFromDenseInputOrInvariant(cmpL) ||
511  isValFromDenseInputOrInvariant(cmpR))
512  return std::make_tuple(cond, tVal, fVal);
513  }
514 
515  return std::nullopt;
516  };
517 };
518 
519 /// Rewrites a sparse reduction that would not sparsify directly since
520 /// doing so would only iterate over the stored elements, ignoring the
521 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
522 /// (note that reductions like add/sub/or/xor can directly be sparsified
523 /// since the implicit zeros do not contribute to the final result).
524 /// Note that prod/and are still included since, even though they often
525 /// are nullified in sparse data, they may still occur for special
526 /// situations in which e.g. some rows in a sparse matrix are fully
527 /// dense. For min/max, including the implicit zeros is a much more
528 /// common situation.
529 ///
530 /// TODO: this essentially "densifies" the operation; we want to implement
531 /// this much more efficiently by performing the reduction over the
532 /// stored values, and feed in the zero once if there were *any*
533 /// implicit zeros as well; but for now, at least we provide
534 /// the functionality
535 ///
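/// As an example (editor's illustration): a kernel x = MIN(i, a(i)) over a
/// sparse vector a would, if sparsified naively, only visit stored entries
/// and miss the implicit zeros. The rewrite below routes every stored entry
/// through a sparse_tensor.unary block (present -> value, absent -> 0) and
/// then combines the results with a sparse_tensor.reduce whose identity is
/// taken from the initialization value.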
536 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
537 public:
538  using OpRewritePattern<GenericOp>::OpRewritePattern;
539 
540  LogicalResult matchAndRewrite(GenericOp op,
541  PatternRewriter &rewriter) const override {
542  // Reject non-reductions.
543  if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 1 ||
544  op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
545  return failure();
546  auto inp = op.getDpsInputOperand(0);
547  auto init = op.getDpsInitOperand(0);
548  if (!isSparseTensor(inp))
549  return failure();
550  // Look for direct x = x OP y for semi-ring ready reductions.
551  auto red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
552  .getOperand(0)
553  .getDefiningOp();
554  if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
555  arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
556  arith::MaxUIOp>(red))
557  return failure();
558  Value s0 = op.getBlock()->getArgument(0);
559  Value s1 = op.getBlock()->getArgument(1);
560  if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
561  (red->getOperand(0) != s1 || red->getOperand(1) != s0))
562  return failure();
563  // Identity.
564  Location loc = op.getLoc();
565  Value identity =
566  rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
567  // Unary {
568  // present -> value
569  // absent -> zero.
570  // }
571  Type rtp = s0.getType();
572  rewriter.setInsertionPointToStart(&op.getRegion().front());
573  auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
574  Block *present =
575  rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
576  rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
577  rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
578  rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
579  rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
580  auto zero =
581  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
582  rewriter.create<sparse_tensor::YieldOp>(loc, zero);
583  rewriter.setInsertionPointAfter(semiring);
584  // CustomReduce {
585  // x = x REDUC y, identity
586  // }
587  auto custom = rewriter.create<sparse_tensor::ReduceOp>(
588  loc, rtp, semiring.getResult(), s1, identity);
589  Block *region =
590  rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
591  rewriter.setInsertionPointToStart(&custom.getRegion().front());
592  IRMapping irMap;
593  irMap.map(red->getOperand(0), region->getArgument(0));
594  irMap.map(red->getOperand(1), region->getArgument(1));
595  auto cloned = rewriter.clone(*red, irMap);
596  rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
597  rewriter.setInsertionPointAfter(custom);
598  rewriter.replaceOp(red, custom.getResult());
599  return success();
600  }
601 };
602 
603 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
604 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
605 public:
606  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
607 
608  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
609  PatternRewriter &rewriter) const override {
610  Location loc = op.getLoc();
611  Value srcTensor = op.getSource();
612  const auto srcTp = getSparseTensorType(srcTensor);
613  const auto dstTp = getSparseTensorType(op.getResult());
614 
615  if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
616  !dstTp.hasStaticDimShape())
617  return failure();
618 
619  SmallVector<Value> srcSizes;
620  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
621  SmallVector<Value> dstSizes;
622  for (Dimension d : dstTp.getDimShape())
623  dstSizes.push_back(constantIndex(rewriter, loc, d));
624 
625  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
626  // Only need an unordered COO buffer if input and output are not sorted
627  // in the same way.
628  Type bufferTp = getBufferType(
629  dstTp.withoutDimToLvl(),
630  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
631  SmallVector<Value> dynSizes;
632  Value buffer = rewriter
633  .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
634  nnz, Attribute())
635  .getResult();
636 
637  // Convert src coordinates to dst coordinates by first collapsing them to 1D
638  // and then expanding them to match the rank of the destination tensor.
639  // Implemented as follows:
640  // foreach srcCoords %srcTensor
641  // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
642  // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
643  // insert expandedCoords, %buffer
644  //
645  // followed by an optional
646  // %t = sparse_tensor.cast %tmp
647  // depending on whether the input/output are sorted in the same way.
648  const auto encSrc = srcTp.getEncoding();
649  ForeachOp foreachOp = rewriter.create<ForeachOp>(
650  loc, srcTensor, buffer,
651  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
652  ValueRange reduc) {
653  const Dimension srcRank = srcTp.getDimRank();
654  SmallVector<Value> srcDcvs;
655  srcDcvs.reserve(srcRank);
656  for (Dimension d = 0; d < srcRank; d++) {
657  Level lvl = toLvl(encSrc, d);
658  srcDcvs.push_back(srcLcvs[lvl]);
659  }
660 
661  Value collapseSize = constantIndex(builder, loc, 1);
662  for (Dimension d = 0; d < srcRank; d++)
663  collapseSize =
664  builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
665  SmallVector<Value, 1> collapsedSizes = {collapseSize};
666 
667  ReassociationIndices collapseIdx;
668  for (Dimension i = 0; i < srcRank; i++)
669  collapseIdx.push_back(i);
670  SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
671  SmallVector<Value, 1> collapsedDcvs;
672  reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
673  collapsedSizes, collapsedDcvs);
674 
675  ReassociationIndices expandIdx;
676  for (Dimension i = 0; i < dstTp.getDimRank(); i++)
677  expandIdx.push_back(i);
678  SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
679  SmallVector<Value> dstDcvs;
680  reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
681  dstSizes, dstDcvs);
682 
683  auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
684  builder.create<sparse_tensor::YieldOp>(loc, t);
685  });
686 
687  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
688  if (bufferTp != dstTp) {
689  auto dstRTT = dstTp.getRankedTensorType();
690  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
691  rewriter.create<DeallocTensorOp>(loc, t);
692  t = converted;
693  }
694  rewriter.replaceOp(op, t);
695  return success();
696  }
697 };
698 
699 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
700 template <typename ReshapeOp>
701 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
702 public:
703  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
704 
705  LogicalResult matchAndRewrite(ReshapeOp op,
706  PatternRewriter &rewriter) const override {
707  Location loc = op.getLoc();
708  Value srcTensor = op.getSrc();
709  const auto srcTp = getSparseTensorType(srcTensor);
710  const auto dstTp = getSparseTensorType(op.getResult());
711  if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
712  return failure();
713 
714  // Generate code to represent the static dimension constants or compute
715  // the dynamic dimension values.
716  SmallVector<Value> srcSizes;
717  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
718  SmallVector<Value> dstSizes;
719  SmallVector<Value> dstDynSizes;
720  if (dstTp.hasStaticDimShape()) {
721  for (Dimension d : dstTp.getDimShape())
722  dstSizes.push_back(constantIndex(rewriter, loc, d));
723  } else {
724  ArrayRef<Size> dstShape = dstTp.getDimShape();
725  genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
726  op.getReassociationIndices());
727  for (auto [idx, shape] : llvm::enumerate(dstShape)) {
728  if (shape == ShapedType::kDynamic)
729  dstDynSizes.push_back(dstSizes[idx]);
730  }
731  }
732  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
733  // Only need an unordered COO buffer if input and output are not sorted
734  // in the same way.
735  Type bufferTp = getBufferType(
736  dstTp.withoutDimToLvl(),
737  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
738 
739  Value buffer =
740  rewriter
741  .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
742  /*sizeHint=*/nnz, Attribute())
743  .getResult();
744 
745  // Implement the sparse2sparse reshape as follows:
746  // foreach srcCoords %srcTensor
747  // insert reshapeCvs(srcCoords), %buffer
748  //
749  // followed by an optional
750  // %t = sparse_tensor.cast %tmp
751  // depending on whether the input/output are sorted in the same way.
752  const auto encSrc = srcTp.getEncoding();
753  ForeachOp foreachOp = rewriter.create<ForeachOp>(
754  loc, srcTensor, buffer,
755  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
756  ValueRange reduc) {
757  const Dimension dimRank = srcTp.getDimRank();
758  SmallVector<Value> srcDcvs;
759  srcDcvs.reserve(dimRank);
760  for (Dimension d = 0; d < dimRank; d++) {
761  Level lvl = toLvl(encSrc, d);
762  srcDcvs.push_back(srcLcvs[lvl]);
763  }
764  SmallVector<Value> dstDcvs;
765  reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
766  srcDcvs, dstSizes, dstDcvs);
767  auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
768  builder.create<sparse_tensor::YieldOp>(loc, t);
769  });
770 
771  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
772  if (bufferTp != dstTp) {
773  auto dstRTT = dstTp.getRankedTensorType();
774  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
775  rewriter.create<DeallocTensorOp>(loc, t);
776  t = converted;
777  }
778  rewriter.replaceOp(op, t);
779  return success();
780  }
781 };
782 
783 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
784 /// operator.
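/// For example (editor's sketch, assuming a #CSR encoding): a dense-to-sparse
/// expansion such as
///   %1 = tensor.expand_shape %d [[0, 1]] : tensor<12xf64> into tensor<3x4xf64, #CSR>
/// is unfused into a purely dense tensor.expand_shape followed by a
/// sparse_tensor.convert into the #CSR result type; the sparse-to-dense
/// direction is handled symmetrically by converting the source to dense first.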
785 template <typename ReshapeOp>
786 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
787 public:
788  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
789 
790  LogicalResult matchAndRewrite(ReshapeOp op,
791  PatternRewriter &rewriter) const override {
792  Location loc = op->getLoc();
793  auto encDst = getSparseTensorEncoding(op.getResult().getType());
794  auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
795  // Since a pure dense expansion is very cheap (change of view), for
796  // a sparse2dense or dense2sparse, we can simply unfuse a sparse
797  // conversion from the reshape operation itself.
798  // All other cases are handled elsewhere.
799  if (encDst && encSrc) {
800  return failure();
801  }
802  if (encSrc) {
803  auto rtp = getRankedTensorType(op.getSrc());
804  auto denseTp =
805  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
806  auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
807  rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
808  return success();
809  }
810  if (encDst) {
811  auto rtp = getRankedTensorType(op.getResult());
812  auto denseTp =
813  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
814  auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
815  op.getReassociation());
816  Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
817  rewriter.replaceOp(op, convert);
818  return success();
819  }
820  return failure();
821  }
822 };
823 
824 // A trivial wrapper to help generate different operations for dense/sparse
825 // tensors.
826 struct TensorLike {
827  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
828  ValueRange sizes) {
829  SmallVector<Value> dynSzs;
830  getDynamicSizes(rtt, sizes, dynSzs);
831 
832  val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
833  if (!isSparse()) {
834  Value c0 = constantZero(builder, loc, rtt.getElementType());
835  val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
836  }
837  }
838 
839  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
840  val = builder.create<tensor::InsertOp>(loc, v, val, crds);
841  }
842 
843  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
844  if (isSparse())
845  return builder.create<LoadOp>(loc, val, true);
846  return val;
847  }
848 
849  bool isSparse() const {
850  return getSparseTensorEncoding(val.getType()) != nullptr;
851  }
852 
853  Value val;
854 };
855 
856 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
857  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
858  LogicalResult matchAndRewrite(tensor::DimOp op,
859  PatternRewriter &rewriter) const override {
860  std::optional<int64_t> dim = op.getConstantIndex();
861  auto stt = getSparseTensorType(op.getSource());
862  if (!dim || !stt.hasEncoding())
863  return failure();
864 
865  if (stt.isPermutation()) {
866  rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
867  toLvl(stt.getEncoding(), *dim));
868  return success();
869  }
870 
871  // Non-permutation dim2lvl/lvl2dim maps.
872  // Compute as follows:
873  // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
874  // Note that it is not the most efficient way (but a more general one) for
875  // the lvl to dim translation, e.g., for BSR, the dimension size can be
876  // computed simply by lvl_size * block_size.
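  // Worked example (editor's note, assuming a 2x2 BSR layout where
  // d0 = l0 * 2 + l2): with level sizes lvlSz0 and 2, the expression below
  // evaluates to ((lvlSz0 - 1) * 2 + (2 - 1)) + 1 = lvlSz0 * 2, i.e. exactly
  // lvl_size * block_size as noted above.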
877  Location loc = op.getLoc();
878  SmallVector<Value> maxLvlCrds;
879  for (Level l = 0; l < stt.getLvlRank(); l++) {
880  Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
881  Value maxLvlCrd = rewriter.create<arith::SubIOp>(
882  loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
883  maxLvlCrds.push_back(maxLvlCrd);
884  }
885 
886  AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
887  Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
888  op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
889  maxLvlCrds);
890 
891  Value dimSz = rewriter.create<arith::AddIOp>(
892  loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
893  rewriter.replaceOp(op, dimSz);
894  return success();
895  }
896 };
897 
898 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
899  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
900  LogicalResult matchAndRewrite(ConcatenateOp op,
901  PatternRewriter &rewriter) const override {
902  if (op.needsExtraSort())
903  op.emitError("ConcatenateOp not staged");
904 
905  const Location loc = op.getLoc();
906  const auto dstTp = getSparseTensorType(op);
907  const Dimension conDim = op.getDimension();
908  SmallVector<Value> sizes;
909  concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
910 
911  // %t = concatenate %s1, %s2, %s3 {dim = 1}
912  // ==>
913  // if (isSparseDst)
914  // if (allDense)
915  // %tmp = bufferization.alloc_tensor dstTp
916  // else
917  // %tmp = bufferization.alloc_tensor : unordered COO
918  // else
919  // %tmp = memref.alloc : dense tensor
920  // foreach in %s1 : insert d0, d1, %tmp
921  // foreach in %s2 : insert d0, d1 + size(s1), %tmp
922  // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
923 
924  TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
925  Value offset = constantIndex(rewriter, loc, 0);
926  Value iterArg = dstBuf.val;
927 
928  ForeachOp foreachOp;
929  for (Value input : op.getInputs()) {
930  // Builds a for op for each input tensor to append new values into the
931  // output tensor.
932  foreachOp = rewriter.create<ForeachOp>(
933  loc, input, iterArg,
934  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
935  ValueRange reduc) {
936  SmallVector<Value> offDimCrd(dcvs);
937  offDimCrd[conDim] =
938  builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
939 
940  // Enters foreach, updates the SSA chain.
941  dstBuf.val = reduc.front();
942  if (!dstTp.isAllDense()) {
943  Value cond = genIsNonzero(builder, loc, v);
944  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
945  /*else*/ true);
946  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
947  builder.create<scf::YieldOp>(loc, dstBuf.val);
948 
949  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
950  dstBuf.insert(builder, loc, v, offDimCrd);
951  builder.create<scf::YieldOp>(loc, dstBuf.val);
952 
953  // Exits the ifOp, update the sparse tensor SSA value.
954  builder.setInsertionPointAfter(ifOp);
955  dstBuf.val = ifOp.getResult(0);
956  } else {
957  dstBuf.insert(builder, loc, v, offDimCrd);
958  }
959  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
960  });
961  // Accumulates the offset. Note that only static-shaped inputs are allowed
962  // by concatenate op verifier, which saves us from computing the offset
963  // dynamically.
964  const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
965  assert(!ShapedType::isDynamic(sz));
966  offset = rewriter.create<arith::AddIOp>(loc, offset,
967  constantIndex(rewriter, loc, sz));
968  iterArg = foreachOp.getResult(0);
969  dstBuf.val = iterArg;
970  }
971 
972  dstBuf.val = iterArg;
973  Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
974  rewriter.replaceOp(op, ret);
975  return success();
976  }
977 };
978 
979 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
980  using OpRewritePattern<ConvertOp>::OpRewritePattern;
981  LogicalResult matchAndRewrite(ConvertOp op,
982  PatternRewriter &rewriter) const override {
983  if (op.needsExtraSort())
984  return op.emitError("ConvertOp not staged.");
985 
986  // TODO: Maybe we want a different operation for this too.
987  auto encDst = getSparseTensorEncoding(op.getType());
988  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
989  if (encDst && encSrc && !encSrc.isSlice() &&
990  encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
991  // Trivial tensor conversion and simple element type conversion are handled
992  // in codegen.
993  return failure();
994  }
995 
996  Location loc = op.getLoc();
997  Value src = op.getSource();
998 
999  SparseTensorType srcStt = getSparseTensorType(op.getSource());
1000  SparseTensorType dstStt = getSparseTensorType(op.getDest());
1001 
1002  bool fromSparseConst = false;
1003  if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1004  if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1005  fromSparseConst = true;
1006 
1007  const AffineMapAttr foreachOrder =
1008  (!dstStt.isIdentity() && fromSparseConst)
1009  ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1010  : nullptr;
1011 
1012  bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1013 
1014  SmallVector<Value> sizes;
1015  sizesFromSrc(rewriter, sizes, loc, src);
1016  ValueRange vs;
1017  TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1018 
1019  auto foreachOp = rewriter.create<ForeachOp>(
1020  loc, src, dstBuf.val, foreachOrder,
1021  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1022  ValueRange reduc) {
1023  // Enters the loop, update the SSA value for insertion chain.
1024  dstBuf.val = reduc.front();
1025  if (!skipZeroCheck) {
1026  Value cond = genIsNonzero(builder, loc, v);
1027  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1028  /*else*/ true);
1029  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1030  builder.create<scf::YieldOp>(loc, dstBuf.val);
1031 
1032  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1033  dstBuf.insert(builder, loc, v, dcvs);
1034  builder.create<scf::YieldOp>(loc, dstBuf.val);
1035 
1036  // Exits the ifOp, update the sparse tensor SSA value.
1037  builder.setInsertionPointAfter(ifOp);
1038  dstBuf.val = ifOp.getResult(0);
1039  } else {
1040  dstBuf.insert(builder, loc, v, dcvs);
1041  }
1042  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1043  });
1044 
1045  rewriter.setInsertionPointAfter(foreachOp);
1046 
1047  // Exits the for loop, links the SSA chain.
1048  dstBuf.val = foreachOp.getResult(0);
1049 
1050  Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1051  rewriter.replaceOp(op, ret);
1052  return success();
1053  }
1054 };
1055 
1056 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1057  using OpRewritePattern<CrdTranslateOp>::OpRewritePattern;
1058  LogicalResult matchAndRewrite(CrdTranslateOp op,
1059  PatternRewriter &rewriter) const override {
1060  AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1061  ? op.getEncoder().getDimToLvl()
1062  : op.getEncoder().getLvlToDim();
1063 
1064  SmallVector<Value> outCrds;
1065  for (AffineExpr result : map.getResults()) {
1066  // TODO: we should probably expand the affine map to IR using our own
1067  // rules, since affine.apply assumes signed values, while the coordinates
1068  // we provide must always be signless.
1069  Value trans = rewriter.create<affine::AffineApplyOp>(
1070  op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1071  op.getInCrds());
1072  outCrds.push_back(trans);
1073  }
1074  rewriter.replaceOp(op, outCrds);
1075  return success();
1076  }
1077 };
1078 
1079 /// Sparse rewriting rule for the foreach operator.
1080 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1081 public:
1082  using OpRewritePattern<ForeachOp>::OpRewritePattern;
1083 
1084  LogicalResult matchAndRewrite(ForeachOp op,
1085  PatternRewriter &rewriter) const override {
1086 
1087  auto loc = op.getLoc();
1088  Value input = op.getTensor();
1089  SmallVector<Value> reduc = op.getInitArgs();
1090  const auto stt = getSparseTensorType(input);
1091  const Level lvlRank = stt.getLvlRank();
1092 
1093  // Special-case: for each over a sparse constant uses its own rewriting
1094  // rule.
1095  if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1096  if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1097  return genForeachOnSparseConstant(op, rewriter, attr);
1098  }
1099  }
1100 
1101  // Otherwise, use loop emitter to generate loops.
1102  const auto enc = stt.getEncoding();
1103 
1104  // 1. Generates loop for the sparse input.
1105  LoopEmitter loopEmitter(
1106  ValueRange{input},
1107  StringAttr::get(getContext(), ForeachOp::getOperationName()));
1108  loopEmitter.initializeLoopEmit(rewriter, loc);
1109  for (Level l = 0; l < lvlRank; l++) {
1110  // TODO: provide utility function for loop sequences that only contains
1111  // one for loop?
1112  const SmallVector<TensorLevel, 1> tidLvls{
1113  loopEmitter.makeTensorLevel(0, l)};
1114  loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1115  // Note that reduc will be taken care of by loop emitter and get updated
1116  // in place.
1117  loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1118  reduc);
1119  }
1120 
1121  SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1122  if (op.getOrder()) {
1123  // TODO: Support it so that we can do direct conversion from CSR->BSR.
1124  llvm_unreachable(
1125  "Level order not yet implemented on non-constant input tensors.");
1126  }
1127 
1128  Value vals = loopEmitter.getValBuffer()[0];
1129  Value pos = loopEmitter.getPosits()[0].back();
1130  // Loads the value from sparse tensor using position-index;
1131  // loads the value from dense tensor using coords.
1132  Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1133  : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1134 
1135  // 2. Inline the block in the foreach operator.
1136  Block *srcBlock = op.getBody();
1137 
1138  // Remap coordinates.
1139  SmallVector<Value> args =
1140  enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1141 
1142  // Remap value.
1143  args.push_back(val);
1144  // Remap reduction variables.
1145  args.append(reduc);
1146 
1147  // Remove sparse_tensor.yield.
1148  SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1149  rewriter.eraseOp(srcBlock->getTerminator());
1150 
1151  // Inline body.
1152  if (!reducValue.empty()) {
1153  rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args);
1154  } else {
1155  // This is annoying, since scf.for inserts an implicit yield op when
1156  // there is no reduction variable upon creation; in this case we need to
1157  // merge the block *before* the yield op.
1158  rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(),
1159  args);
1160  }
1161 
1162  for (Level l = 0; l < lvlRank; l++) {
1163  // Link the reduction chain. Note that the loop emitter updates reducValue
1164  // in place.
1165  loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1166  loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1167  }
1168 
1169  // Replace the foreach operator with the value returned by the outermost
1170  // for loop.
1171  rewriter.replaceOp(op, reducValue);
1172  return success();
1173  }
1174 };
1175 
1176 /// Sparse rewriting rule for the new operator.
1177 struct NewRewriter : public OpRewritePattern<NewOp> {
1178  using OpRewritePattern<NewOp>::OpRewritePattern;
1179  LogicalResult matchAndRewrite(NewOp op,
1180  PatternRewriter &rewriter) const override {
1181  Location loc = op.getLoc();
1182  auto stt = getSparseTensorType(op.getResult());
1183  if (!stt.hasEncoding() || stt.getCOOStart() == 0)
1184  return failure();
1185 
1186  // Implement the NewOp as follows:
1187  // %orderedCoo = sparse_tensor.new %filename
1188  // %t = sparse_tensor.convert %orderedCoo
1189  // with enveloping reinterpret_map ops for non-permutations.
1190  RankedTensorType dstTp = stt.getRankedTensorType();
1191  RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1192  Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1193  Value convert = cooTensor;
1194  auto enc = stt.getEncoding();
1195  if (!stt.isPermutation()) { // demap coo, demap dstTp
1196  auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1197  convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1198  dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1199  }
1200  convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1201  if (!stt.isPermutation()) // remap to original enc
1202  convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1203  rewriter.replaceOp(op, convert);
1204 
1205  // Release the temporary ordered COO tensor.
1206  rewriter.setInsertionPointAfterValue(convert);
1207  rewriter.create<DeallocTensorOp>(loc, cooTensor);
1208 
1209  return success();
1210  }
1211 };
1212 
1213 /// Sparse rewriting rule for the out operator.
1214 struct OutRewriter : public OpRewritePattern<OutOp> {
1215  using OpRewritePattern<OutOp>::OpRewritePattern;
1216  LogicalResult matchAndRewrite(OutOp op,
1217  PatternRewriter &rewriter) const override {
1218  Location loc = op.getLoc();
1219  // Calculate NNZ.
1220  Value src = op.getTensor();
1221  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1222 
1223  // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1224  const auto srcTp = getSparseTensorType(src);
1225  const Dimension dimRank = srcTp.getDimRank();
1226  Type indexTp = rewriter.getIndexType();
1227  Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1228 
1229  // Generate code to calculate dimension size values and store the values to
1230  // the buffer.
1231  SmallVector<Value> dims;
1232  sizesForTensor(rewriter, dims, loc, srcTp, src);
1233  for (Dimension d = 0; d < dimRank; d++) {
1234  rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1235  constantIndex(rewriter, loc, d));
1236  }
1237 
1238  // Create a sparse tensor writer and output meta data.
1239  Type opaqueTp = getOpaquePointerType(rewriter);
1240  Value writer =
1241  createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1242  {op.getDest()}, EmitCInterface::Off)
1243  .getResult(0);
1244  Value rankValue = constantIndex(rewriter, loc, dimRank);
1245  createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1246  {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1247 
1248  Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1249  Type eltTp = srcTp.getElementType();
1250  SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1251  primaryTypeFunctionSuffix(eltTp)};
1252  Value value = genAllocaScalar(rewriter, loc, eltTp);
1253  ModuleOp module = op->getParentOfType<ModuleOp>();
1254 
1255  // For each element in the source tensor, output the element.
1256  rewriter.create<ForeachOp>(
1257  loc, src, std::nullopt,
1258  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1259  ValueRange reduc) {
1260  for (Dimension d = 0; d < dimRank; d++) {
1261  rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1262  constantIndex(builder, loc, d));
1263  }
1264  rewriter.create<memref::StoreOp>(loc, v, value);
1265  SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1266  FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1267  EmitCInterface::On);
1268  builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1269  builder.create<sparse_tensor::YieldOp>(loc);
1270  });
1271 
1272  // Release the writer.
1273  createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1274  EmitCInterface::Off);
1275 
1276  rewriter.eraseOp(op);
1277  return success();
1278  }
1279 };
1280 
1281 } // namespace
1282 
1283 //===---------------------------------------------------------------------===//
1284 // Methods that add patterns described in this file to a pattern list.
1285 //===---------------------------------------------------------------------===//
1286 
1287 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1288  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1289  GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
1290 }
1291 
1292 void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
1293  bool enableRT,
1294  bool enableConvert) {
1295  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1296  ReshapeRewriter<tensor::CollapseShapeOp>,
1297  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1298  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1299  SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1300  patterns.getContext());
1301 
1302  if (enableConvert)
1303  patterns.add<DirectConvertRewriter>(patterns.getContext());
1304  if (!enableRT)
1305  patterns.add<NewRewriter>(patterns.getContext());
1306 }
1307 
1308 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1309  // Run CrdTranslateRewriter later in the pipeline so that the operation can
1310  // be folded before lowering to affine.apply.
1311  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1312 }
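// Editor's usage sketch (not part of this file): a pass would typically hand
// these entry points to the greedy pattern driver, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populatePreSparsificationRewriting(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();
//
// where applyPatternsAndFoldGreedily is declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h.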
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static bool isMulChain(Value val, Value x)
static bool isSampling(GenericOp op)
static bool isSumOfMul(GenericOp op)
static bool isZeroValue(Value val)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static LogicalResult genForeachOnSparseConstant(ForeachOp op, RewriterBase &rewriter, SparseElementsAttr attr)
static bool isMaterializing(OpOperand *op, bool isZero)
static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType dstTp, ValueRange srcs, unsigned dim)
Populates the given sizes array for concatenation from types (for static sizes) and from the source t...
static bool isSparseTensor(Value v)
static bool isZeroYield(GenericOp op)
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType stp, Value tensor)
Populates given sizes array from type (for static sizes) and from the tensor (for dynamic sizes).
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:374
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:387
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
Location getLoc() const
Return the location for this argument.
Definition: Value.h:330
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
Operation & back()
Definition: Block.h:145
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:147
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:528
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:433
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:828
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
U dyn_cast() const
Definition: Value.h:106
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the dynamic size of the requested dimension.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
RankedTensorType getRankedTensorType() const
Explicitly convert to RankedTensorType.
AffineMap getExpandedDimToLvl() const
Returns the dimToLvl mapping, where the identity map is expanded out into a full AffineMap.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:88
auto m_Any()
Definition: Matchers.h:448
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Returns a function reference (first hit also inserts into module).
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:361
void foreachInSparseConstant(OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, function_ref< void(ArrayRef< Value >, Value)> callback)
Iterates over a sparse constant, generating a constant op for each value and its coordinates.
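A hedged sketch of the callback shape this helper expects; the wrapper name is hypothetical and `rank` is assumed to be the dimension rank of `attr`:

  #include "CodegenUtils.h"            // local sparse-tensor codegen helpers
  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/Builders.h"
  #include "mlir/IR/BuiltinAttributes.h"

  // Sketch: visit every stored element of a sparse constant; `coords` holds
  // index-typed constants and `val` the element value as a constant result.
  static void visitSparseConstant(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::SparseElementsAttr attr, unsigned rank) {
    auto order =
        mlir::AffineMap::getMultiDimIdentityMap(rank, builder.getContext());
    mlir::sparse_tensor::foreachInSparseConstant(
        builder, loc, attr, order,
        [&](llvm::ArrayRef<mlir::Value> coords, mlir::Value val) {
          // e.g. insert (coords, val) into a tensor under construction
        });
  }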
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:339
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Definition: CodegenUtils.h:350
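A brief sketch of the constant helpers listed above (constantIndex, constantZero, constantOne); the f64 element type and the wrapper function are arbitrary choices for illustration:

  #include "CodegenUtils.h"
  #include "mlir/IR/Builders.h"

  // Sketch: materialize the typed constants that rewriting rules commonly need.
  static void makeCommonConstants(mlir::OpBuilder &builder, mlir::Location loc) {
    mlir::Value i0   = mlir::sparse_tensor::constantIndex(builder, loc, 0);
    mlir::Value zero = mlir::sparse_tensor::constantZero(builder, loc, builder.getF64Type());
    mlir::Value one  = mlir::sparse_tensor::constantOne(builder, loc, builder.getF64Type());
    (void)i0; (void)zero; (void)one;
  }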
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:35
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
Definition: SparseTensor.h:42
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:74
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
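For example, a hedged sketch that combines the generated comparison with another flag; the helper name is hypothetical and `otherFlag` is assumed to be an i1 value:

  #include "CodegenUtils.h"
  #include "mlir/Dialect/Arith/IR/Arith.h"

  // Sketch: build the i1 condition "v != 0" and AND it with another i1 flag.
  static mlir::Value nonzeroAnd(mlir::OpBuilder &builder, mlir::Location loc,
                                mlir::Value v, mlir::Value otherFlag) {
    mlir::Value isNonzero = mlir::sparse_tensor::genIsNonzero(builder, loc, v);
    return builder.create<mlir::arith::AndIOp>(loc, isNonzero, otherFlag);
  }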
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
void genReshapeDstShape(OpBuilder &builder, Location loc, SmallVectorImpl< Value > &dstShape, ArrayRef< Value > srcShape, ArrayRef< Size > staticDstShape, ArrayRef< ReassociationIndices > reassociation)
Computes the shape of the destination tensor of a reshape operator.
SparseTensorType getSparseTensorType(Value val)
Convenience method to obtain a SparseTensorType from a Value.
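A small sketch of how the encoding query and the SparseTensorType wrapper relate; the helper name and the level-rank use are illustrative assumptions:

  #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
  #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

  // Sketch: return the level-rank of `v` if it carries a sparse encoding,
  // and 0 otherwise (a null encoding means the tensor type is dense).
  static mlir::sparse_tensor::Level lvlRankOrZero(mlir::Value v) {
    if (!mlir::sparse_tensor::getSparseTensorEncoding(v.getType()))
      return 0;
    return mlir::sparse_tensor::getSparseTensorType(v).getLvlRank();
  }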
void reshapeCvs(OpBuilder &builder, Location loc, ArrayRef< ReassociationIndices > reassociation, ValueRange srcSizes, ValueRange srcCvs, ValueRange dstSizes, SmallVectorImpl< Value > &dstCvs)
Reshapes coordinates during a reshaping operation.
bool hasAnySparseOperand(Operation *op)
Returns true iff the MLIR operation has any sparse operand.
Definition: SparseTensor.h:93
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
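A hedged sketch of the runtime-call helper; the function name "hypotheticalSparseRuntimeFn", its single operand, and its index result are placeholders rather than real runtime entry points:

  #include "CodegenUtils.h"
  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/Builders.h"

  // Sketch: declare (if needed) and call a runtime function returning an index.
  static mlir::Value callRuntime(mlir::OpBuilder &builder, mlir::Location loc,
                                 mlir::Value arg) {
    mlir::Type idxTp = builder.getIndexType();
    mlir::func::CallOp call = mlir::sparse_tensor::createFuncCall(
        builder, loc, "hypotheticalSparseRuntimeFn", idxTp, {arg},
        mlir::sparse_tensor::EmitCInterface::On);
    return call.getResult(0);
  }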
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, Value src)
Populates the given sizes array from a dense tensor or a sparse tensor constant.
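A short sketch, assuming `src` is the tensor whose per-dimension sizes are needed (the wrapper name is hypothetical):

  #include "CodegenUtils.h"
  #include "mlir/IR/Builders.h"
  #include "llvm/ADT/SmallVector.h"

  // Sketch: gather the dimension sizes of `src` as index Values, e.g. in order
  // to allocate a destination tensor of the same shape.
  static llvm::SmallVector<mlir::Value> gatherSizes(mlir::OpBuilder &builder,
                                                    mlir::Location loc,
                                                    mlir::Value src) {
    llvm::SmallVector<mlir::Value> sizes;
    mlir::sparse_tensor::sizesFromSrc(builder, sizes, loc, src);
    return sizes;
  }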
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:118
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:340
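For instance, a hedged sketch of a simple fold that uses these matchers; it is not one of the rules in this file, and the helper name is hypothetical:

  #include "mlir/Dialect/Arith/IR/Arith.h"
  #include "mlir/IR/Matchers.h"

  // Sketch: x + 0 folds to x when the rhs matches a constant integer zero.
  static mlir::Value foldAddOfZero(mlir::arith::AddIOp addOp) {
    if (mlir::matchPattern(addOp.getRhs(), mlir::m_Zero()))
      return addOp.getLhs();
    return {}; // no fold applies
  }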
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute constructio...
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
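A sketch of how a pass might assemble and apply one of these pattern sets; the surrounding pass boilerplate is omitted, and `op` and `ctx` are assumed to be the pass's root operation and its context:

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Sketch: run the pre-sparsification rewrites greedily on `op`.
  static mlir::LogicalResult runPreRewriting(mlir::Operation *op,
                                             mlir::MLIRContext *ctx) {
    mlir::RewritePatternSet patterns(ctx);
    mlir::populatePreSparsificationRewriting(patterns);
    return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
  }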
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
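To tie the pattern-related entries together, here is a minimal, hedged sketch of the OpRewritePattern idiom; this trivial cast-folding pattern is an illustration only, not one of the rewriting rules in this file:

  #include "mlir/Dialect/Tensor/IR/Tensor.h"
  #include "mlir/IR/PatternMatch.h"

  // Sketch: fold a tensor.cast whose source already has the result type.
  struct FoldTrivialCast : public mlir::OpRewritePattern<mlir::tensor::CastOp> {
    using OpRewritePattern::OpRewritePattern;
    mlir::LogicalResult
    matchAndRewrite(mlir::tensor::CastOp op,
                    mlir::PatternRewriter &rewriter) const override {
      if (op.getSource().getType() != op.getType())
        return mlir::failure();               // not a trivial cast
      rewriter.replaceOp(op, op.getSource()); // forward the source value
      return mlir::success();
    }
  };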