MLIR  20.0.0git
SparseTensorRewriting.cpp
Go to the documentation of this file.
1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenUtils.h"
14 #include "Utils/LoopEmitter.h"
15 
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/Support/LLVM.h"
32 
33 using namespace mlir;
34 using namespace mlir::bufferization;
35 using namespace mlir::linalg;
36 using namespace mlir::sparse_tensor;
37 
38 //===---------------------------------------------------------------------===//
39 // Helper methods for the actual rewriting rules.
40 //===---------------------------------------------------------------------===//
41 
42 // Helper method to match any typed zero.
43 static bool isZeroValue(Value val) {
44  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45 }
46 
47 // Helper to detect a sparse tensor type operand.
48 static bool isSparseTensor(Value v) {
49  auto enc = getSparseTensorEncoding(v.getType());
50  return enc && !llvm::all_of(enc.getLvlTypes(),
51  [](auto lt) { return lt == LevelFormat::Dense; });
52 }
53 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54 
55 // Helper method to find zero/uninitialized tensor materialization.
56 static bool isMaterializing(OpOperand *op, bool isZero) {
57  Value val = op->get();
58  // Check allocation, with zero alloc when required.
59  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60  Value copy = alloc.getCopy();
61  if (isZero)
62  return copy && isZeroValue(copy);
63  return !copy;
64  }
65  // Check for empty tensor materialization.
66  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67  return !isZero;
68  // Last resort for zero alloc: the whole value is zero.
69  return isZero && isZeroValue(val);
70 }
71 
72 // Helper to detect sampling operation.
73 static bool isSampling(GenericOp op) {
74  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77  // Both scalar input arguments used exactly once.
78  Value s1 = op.getBlock()->getArgument(0);
79  Value s2 = op.getBlock()->getArgument(1);
80  return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81  (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82  }
83  }
84  return false;
85 }
86 
87 // Helper to detect chain of multiplications that do not involve x.
88 static bool isMulChain(Value val, Value x) {
89  if (auto arg = dyn_cast<BlockArgument>(val))
90  return arg != x;
91  if (auto *def = val.getDefiningOp()) {
92  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93  return isMulChain(def->getOperand(0), x) &&
94  isMulChain(def->getOperand(1), x);
95  }
96  return false;
97 }
98 
99 // Helper to detect x = x + <multiplications>.
100 static bool isSumOfMul(GenericOp op) {
101  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104  Value x = op.getBlock()->getArguments().back();
105  return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106  (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107  }
108  }
109  return false;
110 }
111 
112 // Helper to detect direct yield of a zero value.
113 static bool isZeroYield(GenericOp op) {
114  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116  if (arg.getOwner()->getParentOp() == op) {
117  return isZeroValue(op->getOperand(arg.getArgNumber()));
118  }
119  }
120  return isZeroValue(yieldOp.getOperand(0));
121 }
122 
123 /// Populates given sizes array from type (for static sizes) and from
124 /// the tensor (for dynamic sizes).
125 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126  Location loc, ShapedType stp, Value tensor) {
127  for (const auto &d : enumerate(stp.getShape())) {
128  Value dim;
129  if (d.value() == ShapedType::kDynamic)
130  dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131  else
132  dim = constantIndex(builder, loc, d.value());
133  sizes.push_back(dim);
134  }
135 }
136 
137 static RankedTensorType getBufferType(const SparseTensorType &stt,
138  bool needTmpCOO) {
139  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140  : stt.getRankedTensorType();
141 }
142 
143 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145 /// sizes to dynSizes.
146 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147  SmallVectorImpl<Value> &dynSizes) {
148  for (const auto &d : enumerate(tp.getShape())) {
149  if (d.value() == ShapedType::kDynamic)
150  dynSizes.push_back(sizes[d.index()]);
151  }
152 }
153 
154 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155  RewriterBase &rewriter,
156  SparseElementsAttr attr) {
157  auto loc = op.getLoc();
158  SmallVector<Value> reduc = op.getInitArgs();
159 
160  // Foreach on constant.
162  rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163  [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164  SmallVector<Value> args;
165  args.append(cvs.begin(), cvs.end());
166  args.push_back(v);
167  args.append(reduc);
168  // Clones the foreach op to get a copy of the loop body.
169  auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170  assert(args.size() == cloned.getBody()->getNumArguments());
171  Operation *yield = cloned.getBody()->getTerminator();
172  rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173  // clean up
174  rewriter.eraseOp(cloned);
175  reduc = yield->getOperands();
176  rewriter.eraseOp(yield);
177  });
178 
179  rewriter.replaceOp(op, reduc);
180  return success();
181 }
182 
183 /// Populates the given sizes array for concatenation from types (for static
184 /// sizes) and from the source tensors (for dynamic sizes).
185 static void concatSizesFromInputs(OpBuilder &builder,
186  SmallVectorImpl<Value> &sizes, Location loc,
187  ShapedType dstTp, ValueRange srcs,
188  unsigned dim) {
189  auto dstShape = dstTp.getShape();
190  sizesFromSrc(builder, sizes, loc, srcs[0]);
191 
192  // Sum up on the `dim` if the dimension is dynamic.
193  if (dstShape[dim] != ShapedType::kDynamic) {
194  // Faithfully take the static size.
195  sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196  } else {
197  // Else, compute the shape dynamically.
198  for (const auto &src : srcs.drop_front()) {
199  Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200  // Sum up all the sizes.
201  sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202  }
203  }
204 }
205 
206 //===---------------------------------------------------------------------===//
207 // The actual sparse tensor rewriting rules.
208 //===---------------------------------------------------------------------===//
209 
210 namespace {
211 
212 /// TODO: move it to tensor dialect instead.
213 ///
214 /// Fold `tensor.concat` and `tensor.extract_slice`
215 ///
216 /// %concat = tensor.concat dim(2) %t0, %t1
217 /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218 /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220 /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222 ///
223 /// Becomes
224 ///
225 /// %extract0, %extract1 = %t0, %t1
226 struct FuseExtractSliceWithConcat
227  : public OpRewritePattern<tensor::ExtractSliceOp> {
229 
230  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231  PatternRewriter &rewriter) const override {
232  auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233  if (!concatOp)
234  return failure();
235 
236  Location loc = extractOp.getLoc();
237  int64_t dim = concatOp.getDim();
238  int64_t rank = extractOp.getResultType().getRank();
239 
240  SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241  SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242 
243  // Compute the partial sums for the slice offsets.
244  AffineExpr sum = rewriter.getAffineDimExpr(0);
245  SmallVector<AffineExpr> partialSums = {sum};
246  SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247  for (auto [idx, input] :
248  llvm::enumerate(concatOp.getInputs().drop_back())) {
249  sum = sum + rewriter.getAffineDimExpr(idx + 1);
250  partialSums.push_back(sum);
251  offsetStrides.push_back(
252  rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253  }
254  auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255  partialSums, rewriter.getContext());
256  SmallVector<OpFoldResult> dimOffsets =
258  rewriter, loc, partialSumMap, offsetStrides);
259 
260  auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261  for (auto [l, r] : llvm::zip(lhs, rhs)) {
262  std::optional<int64_t> staticVal = getConstantIntValue(l);
263  if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264  return false;
265  }
266  return lhs.size() == rhs.size();
267  };
268 
269  for (auto [i, input, offset] :
270  llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271  SmallVector<OpFoldResult> srcSizes =
272  tensor::getMixedSizes(rewriter, loc, input);
273  srcOffsets[dim] = offset;
274 
275  SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276  SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277  SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278 
279  if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280  allEqual(srcStrides, dstStrides)) {
281  Value operand = concatOp.getOperand(i);
282  if (operand.getType() == extractOp.getResultType())
283  rewriter.replaceOp(extractOp, operand);
284  break;
285  }
286  }
287 
288  return success();
289  }
290 };
291 
292 /// Rewriting rule that fuses sparse_tensor.convert into producer.
293 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294 public:
296 
297  LogicalResult matchAndRewrite(ConvertOp op,
298  PatternRewriter &rewriter) const override {
299  auto producer = op.getSource().getDefiningOp<GenericOp>();
300  if (!producer || producer.getDpsInits().size() != 1 ||
301  !isMaterializing(producer.getDpsInitOperand(0), false) ||
302  !producer.getResult(0).hasOneUse()) {
303  return failure();
304  }
305  // Clone the materialization operation, but update the result to sparse.
306  rewriter.setInsertionPoint(producer);
307  Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
308  Operation *cloned = rewriter.clone(*init);
309  cloned->getResult(0).setType(op.getResult().getType());
310 
311  rewriter.modifyOpInPlace(producer, [&]() {
312  producer.getDpsInitsMutable().assign(cloned->getResults());
313  producer.getResult(0).setType(op.getResult().getType());
314  });
315 
316  rewriter.replaceAllOpUsesWith(op, producer);
317  op->erase();
318 
319  return success();
320  }
321 };
322 
323 /// Rewriting rule that converts direct yield of zero with initial allocation.
324 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325 public:
327 
328  LogicalResult matchAndRewrite(GenericOp op,
329  PatternRewriter &rewriter) const override {
330  if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331  !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
332  !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333  return failure();
334  auto outputType = getRankedTensorType(op.getResult(0));
335  // Yielding zero on newly materialized sparse tensor can be
336  // optimized directly (regardless of dynamic or static size).
337  if (getSparseTensorEncoding(outputType)) {
338  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339  return success();
340  }
341  // Use static zero value directly instead of materialization.
342  if (!outputType.hasStaticShape())
343  return failure();
344  Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
345  rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
346  rewriter.eraseOp(def);
347  return success();
348  }
349 };
350 
351 /// Rewriting rule that converts two kernels:
352 ///
353 /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
354 /// X(i,j) = S(i,j) * T(i,j)
355 ///
356 /// into a single kernel, using distributive law:
357 ///
358 /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
359 ///
360 /// This kind of fusion (merging two ops into one but using arithmetic
361 /// equalities that may not hold for floating-point computations) would
362 /// be undesirable in the dense case, since we distribute the multiplication
363 /// into the reduction loop. However, for sparse sampling tensor S, such
364 /// a fusion may actually reduce the asymptotic complexity of the kernel,
365 /// since intermediate results may be nullified.
366 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
367 public:
369 
370  LogicalResult matchAndRewrite(GenericOp op,
371  PatternRewriter &rewriter) const override {
372  // Check consumer.
373  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374  op.getNumResults() != 1 ||
375  op.getNumParallelLoops() != op.getNumLoops() ||
376  !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377  !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378  !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379  return failure();
380  // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
381  // operand can be sparse or dense, since the point of this rewriting rule
382  // is detecting a situation in which *more* sparsity is introduced into
383  // a computation, be it already sparse or still dense.
384  unsigned other = 0;
385  if (isSparseTensor(op.getDpsInputOperand(0)))
386  other = 1;
387  else if (!isSparseTensor(op.getDpsInputOperand(1)))
388  return failure();
389  // Check producer.
390  auto prod = dyn_cast_or_null<GenericOp>(
391  op.getDpsInputOperand(other)->get().getDefiningOp());
392  if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393  !prod.getResult(0).hasOneUse())
394  return failure();
395  // Sampling consumer and sum of multiplication chain producer.
396  if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
397  !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398  !isSampling(op) || !isSumOfMul(prod))
399  return failure();
400  // Modify operand structure of producer and consumer.
401  Location loc = prod.getLoc();
402  SmallVector<Value> inputOps = prod.getInputs();
403  SmallVector<Value> outputOps = op.getOutputs();
404  SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
405  inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406  fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
407  // Fuse producer and consumer into a new generic op.
408  auto fusedOp = rewriter.create<GenericOp>(
409  loc, op.getResult(0).getType(), inputOps, outputOps,
410  rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
411  /*doc=*/nullptr, /*library_call=*/nullptr);
412  Block &prodBlock = prod.getRegion().front();
413  Block &consBlock = op.getRegion().front();
414  IRMapping mapper;
415  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
416  unsigned num = prodBlock.getNumArguments();
417  for (unsigned i = 0; i < num - 1; i++)
418  addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419  addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
420  addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421  // Clone bodies of the producer and consumer in new evaluation order.
422  auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
423  auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
424  Value last;
425  for (auto &op : prodBlock.without_terminator())
426  if (&op != acc) {
427  last = op.getResult(0);
428  rewriter.clone(op, mapper);
429  }
430  mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
431  mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
432  last = rewriter.clone(*acc, mapper)->getResult(0);
433  rewriter.create<linalg::YieldOp>(loc, last);
434  // Force initial value on merged allocation for dense outputs.
435  // TODO: deal with non alloc tensor here one day
436  if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437  Value init = prod.getDpsInitOperand(0)
438  ->get()
439  .getDefiningOp<AllocTensorOp>()
440  .getCopy();
441  AllocTensorOp a =
442  op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443  rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444  }
445  // Replace consumer with fused operation. Old producer
446  // and consumer ops will be removed by DCE.
447  rewriter.replaceOp(op, fusedOp->getResults());
448  return success();
449  }
450 
451 private:
452  // Helper to add argument and record the mapping.
453  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
454  mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
455  }
456 };
457 
458 // Fuse a tensor cast into producing operation. Note that a tensor.cast
459 // should really not be used to convert between sparse encodings. Since
460 // the pattern currently appears as a result of some prior rewriting
461 // we make an attempt to repair very obvious cases.
462 // TODO: audit the pure tensor dialect rewriting rules
463 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
464 public:
466 
467  LogicalResult matchAndRewrite(tensor::CastOp op,
468  PatternRewriter &rewriter) const override {
469  Type srcType = op.getSource().getType();
470  Type dstType = op.getDest().getType();
471  // A nop cast simply folds away.
472  if (srcType == dstType) {
473  rewriter.replaceOp(op, op->getResults());
474  return success();
475  }
476  // See if a sparsity changing cast can be fused into producer.
477  if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
478  if (Operation *def = op.getSource().getDefiningOp()) {
479  if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
480  rewriter.modifyOpInPlace(def, [&]() {
481  def->getResult(0).setType(op->getResultTypes()[0]);
482  });
483  rewriter.replaceOp(op, def->getResult(0));
484  return success();
485  }
486  }
487  }
488  // Repair tensor casts with at least one sparse operand into the
489  // the properly supported sparse_tensor.convert.
490  if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
491  rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
492  return success();
493  }
494  // Fail otherwise.
495  return failure();
496  }
497 };
498 
499 /// Rewrites a sequence of operations for sparse tensor selections in to
500 /// semi-ring operations such that they can be compiled correctly by the
501 /// sparsifier. E.g., transforming the following sequence
502 ///
503 /// %sel = arith.select %cond, %sp1, %sp2
504 ///
505 /// to
506 ///
507 /// %sel = binary %sp1, %sp2:
508 /// both (%l, %r) {yield select %cond, %l, %r}
509 /// left (%l) {yield select %cond, %l, 0}
510 /// right (%r) {yield select %cond, 0, %r}
511 ///
512 /// TODO: We require that the tensor used for extracting conditions to be dense
513 /// to sparsify the code. To support a sparse condition tensor, we need a
514 /// tri-nary operation.
515 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516 public:
518  LogicalResult matchAndRewrite(GenericOp op,
519  PatternRewriter &rewriter) const override {
520  // Rejects non sparse kernels.
521  if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522  return failure();
523 
524  Location loc = op.getLoc();
526  for (Operation &inst : *op.getBody()) {
527  // Matches pattern.
528  auto matched = isRewritablePattern(op, &inst);
529  if (!matched.has_value())
530  continue;
531 
532  rewriter.setInsertionPoint(&inst);
533  auto [c, t, f] = matched.value();
534  assert(t.getType() == f.getType());
535  auto selTp = t.getType();
536  auto c0 = constantZero(rewriter, loc, selTp);
537  auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
538  // Initializes all the blocks.
539  rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540  {t.getLoc(), f.getLoc()});
541  rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542  rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543 
544  for (auto *r : binOp.getRegions()) {
545  Block *b = &r->front();
546  rewriter.setInsertionPointToStart(b);
547 
548  IRMapping irMap;
549  // Clones the cmp operations into the region to make the binary op
550  // admissible.
551  Value newC = c;
552  if (auto *def = c.getDefiningOp())
553  newC = rewriter.clone(*def, irMap)->getResult(0);
554 
555  irMap.map(c, newC);
556  if (r == &binOp.getLeftRegion()) {
557  irMap.map(t, b->getArgument(0));
558  irMap.map(f, c0);
559  } else if (r == &binOp.getRightRegion()) {
560  irMap.map(t, c0);
561  irMap.map(f, b->getArgument(0));
562  } else {
563  irMap.map(t, b->getArgument(0));
564  irMap.map(f, b->getArgument(1));
565  }
566  auto y = rewriter.clone(inst, irMap)->getResult(0);
567  rewriter.create<sparse_tensor::YieldOp>(loc, y);
568  }
569 
570  // We successfully rewrited a operation. We can not do replacement here
571  // becuase it invalidate the iterator for the current loop to traverse
572  // the instructions.
573  semiRings.emplace_back(&inst, binOp);
574  }
575 
576  // Finalizes the replacement.
577  for (auto [sel, semi] : semiRings)
578  rewriter.replaceOp(sel, semi->getResults());
579 
580  return success(!semiRings.empty());
581  }
582 
583 private:
584  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585  isRewritablePattern(GenericOp op, Operation *v) {
586  auto sel = dyn_cast<arith::SelectOp>(v);
587  if (!sel)
588  return std::nullopt;
589 
590  auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591  auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
592  // TODO: For simplicity, we only handle cases where both true/false value
593  // are directly loaded the input tensor. We can probably admit more cases
594  // in theory.
595  if (!tVal || !fVal)
596  return std::nullopt;
597 
598  // Helper lambda to determine whether the value is loaded from a dense input
599  // or is a loop invariant.
600  auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601  if (auto bArg = dyn_cast<BlockArgument>(v);
602  bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603  return true;
604  // If the value is defined outside the loop, it is a loop invariant.
605  return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606  };
607 
608  // If the condition value is load directly from a dense tensor or
609  // loop-invariants, we can sparsify the kernel.
610  auto cond = sel.getCondition();
611  if (isValFromDenseInputOrInvariant(cond))
612  return std::make_tuple(cond, tVal, fVal);
613 
614  Value cmpL, cmpR;
615  if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616  matchers::m_Any(&cmpR))) ||
617  matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618  matchers::m_Any(&cmpR)))) {
619  // TODO: we can do it recursively to check whether all the leaf values are
620  // loaded from dense tensors or are loop invariants.
621  if (isValFromDenseInputOrInvariant(cmpL) ||
622  isValFromDenseInputOrInvariant(cmpR))
623  return std::make_tuple(cond, tVal, fVal);
624  }
625 
626  return std::nullopt;
627  };
628 };
629 
630 /// Rewrites a sparse reduction that would not sparsify directly since
631 /// doing so would only iterate over the stored elements, ignoring the
632 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
633 /// (note that reductions like add/sub/or/xor can directly be sparsified
634 /// since the implicit zeros do not contribute to the final result).
635 /// Note that prod/and are still included since, even though they often
636 /// are nullified in sparse data, they may still occur for special
637 /// situations in which e.g. some rows in a sparse matrix are fully
638 /// dense. For min/max, including the implicit zeros is a much more
639 /// common situation.
640 ///
641 /// TODO: this essentially "densifies" the operation; we want to implement
642 /// this much more efficiently by performing the reduction over the
643 /// stored values, and feed in the zero once if there were *any*
644 /// implicit zeros as well; but for now, at least we provide
645 /// the functionality
646 ///
647 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
648 public:
650 
651  LogicalResult matchAndRewrite(GenericOp op,
652  PatternRewriter &rewriter) const override {
653  // Reject non-reductions.
654  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655  op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656  return failure();
657  auto *inp = op.getDpsInputOperand(0);
658  auto *init = op.getDpsInitOperand(0);
659  if (!isSparseTensor(inp))
660  return failure();
661  // Look for direct x = x OP y for semi-ring ready reductions.
662  auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
663  .getOperand(0)
664  .getDefiningOp();
665  if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666  arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667  arith::MaxUIOp>(red))
668  return failure();
669  Value s0 = op.getBlock()->getArgument(0);
670  Value s1 = op.getBlock()->getArgument(1);
671  if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672  (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673  return failure();
674  // Identity.
675  Location loc = op.getLoc();
676  Value identity =
677  rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
678  // Unary {
679  // present -> value
680  // absent -> zero.
681  // }
682  Type rtp = s0.getType();
683  rewriter.setInsertionPointToStart(&op.getRegion().front());
684  auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
685  Block *present =
686  rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
687  rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
688  rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
689  rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
690  rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
691  auto zero =
692  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
693  rewriter.create<sparse_tensor::YieldOp>(loc, zero);
694  rewriter.setInsertionPointAfter(semiring);
695  // CustomReduce {
696  // x = x REDUC y, identity
697  // }
698  auto custom = rewriter.create<sparse_tensor::ReduceOp>(
699  loc, rtp, semiring.getResult(), s1, identity);
700  Block *region =
701  rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
702  rewriter.setInsertionPointToStart(&custom.getRegion().front());
703  IRMapping irMap;
704  irMap.map(red->getOperand(0), region->getArgument(0));
705  irMap.map(red->getOperand(1), region->getArgument(1));
706  auto *cloned = rewriter.clone(*red, irMap);
707  rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
708  rewriter.setInsertionPointAfter(custom);
709  rewriter.replaceOp(red, custom.getResult());
710  return success();
711  }
712 };
713 
714 /// Sparse rewriting rule for the print operator. This operation is mainly used
715 /// for debugging and testing. As such, it lowers to the vector.print operation
716 /// which only require very light-weight runtime support.
717 struct PrintRewriter : public OpRewritePattern<PrintOp> {
718 public:
720  LogicalResult matchAndRewrite(PrintOp op,
721  PatternRewriter &rewriter) const override {
722  Location loc = op.getLoc();
723  auto tensor = op.getTensor();
724  auto stt = getSparseTensorType(tensor);
725  // Header with NSE.
726  auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
727  rewriter.create<vector::PrintOp>(
728  loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
729  rewriter.create<vector::PrintOp>(loc, nse);
730  // Print run-time contents for dim/lvl sizes.
731  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
732  printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
733  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
734  printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
735  // Use the "codegen" foreach loop construct to iterate over
736  // all typical sparse tensor components for printing.
737  foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
738  &stt](Type, FieldIndex,
740  Level l, LevelType) {
741  switch (kind) {
742  case SparseTensorFieldKind::StorageSpec: {
743  break;
744  }
745  case SparseTensorFieldKind::PosMemRef: {
746  auto lvl = constantIndex(rewriter, loc, l);
747  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
748  rewriter.create<vector::PrintOp>(
749  loc, lvl, vector::PrintPunctuation::NoPunctuation);
750  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
751  auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
752  printContents(rewriter, loc, pos);
753  break;
754  }
755  case SparseTensorFieldKind::CrdMemRef: {
756  auto lvl = constantIndex(rewriter, loc, l);
757  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
758  rewriter.create<vector::PrintOp>(
759  loc, lvl, vector::PrintPunctuation::NoPunctuation);
760  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
761  Value crd = nullptr;
762  // For COO AoS storage, we want to print a single, linear view of
763  // the full coordinate storage at this level. For any other storage,
764  // we show the coordinate storage for every indivual level.
765  if (stt.getAoSCOOStart() == l)
766  crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
767  else
768  crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
769  printContents(rewriter, loc, crd);
770  break;
771  }
772  case SparseTensorFieldKind::ValMemRef: {
773  rewriter.create<vector::PrintOp>(loc,
774  rewriter.getStringAttr("values : "));
775  auto val = rewriter.create<ToValuesOp>(loc, tensor);
776  printContents(rewriter, loc, val);
777  break;
778  }
779  }
780  return true;
781  });
782  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
783  rewriter.eraseOp(op);
784  return success();
785  }
786 
787 private:
788  // Helper to print contents of a single memref. For "push_back" vectors,
789  // we assume that the previous getters for pos/crd/val have added a
790  // slice-to-size view to make sure we just print the size and not the
791  // full capacity.
792  //
793  // Generates code to print (1-dim or higher):
794  // ( a0, a1, ... )
795  static void printContents(PatternRewriter &rewriter, Location loc,
796  Value vec) {
797  auto shape = cast<ShapedType>(vec.getType()).getShape();
798  SmallVector<Value> idxs;
799  printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
801  }
802 
803  // Helper to the helper.
804  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
805  Value vec, unsigned i, ArrayRef<int64_t> shape,
806  SmallVectorImpl<Value> &idxs) {
807  // Open bracket.
808  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
809  // Generate for loop.
810  auto zero = constantIndex(rewriter, loc, 0);
811  auto index = constantIndex(rewriter, loc, i);
812  auto size = rewriter.create<memref::DimOp>(loc, vec, index);
813  auto step = constantIndex(rewriter, loc, 1);
814  auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
815  idxs.push_back(forOp.getInductionVar());
816  rewriter.setInsertionPointToStart(forOp.getBody());
817  if (i < shape.size() - 1) {
818  // Enter deeper loop nest.
819  printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
820  } else {
821  // Actual contents printing.
822  auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
823  if (llvm::isa<ComplexType>(val.getType())) {
824  // Since the vector dialect does not support complex types in any op,
825  // we split those into (real, imag) pairs here.
826  Value real = rewriter.create<complex::ReOp>(loc, val);
827  Value imag = rewriter.create<complex::ImOp>(loc, val);
828  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
829  rewriter.create<vector::PrintOp>(loc, real,
830  vector::PrintPunctuation::Comma);
831  rewriter.create<vector::PrintOp>(loc, imag,
832  vector::PrintPunctuation::Close);
833  } else {
834  rewriter.create<vector::PrintOp>(
835  loc, val, vector::PrintPunctuation::NoPunctuation);
836  }
837  // Terminating comma (except at end).
838  auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
839  Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
840  bound, size);
841  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
842  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
843  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
844  }
845  idxs.pop_back();
846  rewriter.setInsertionPointAfter(forOp);
847  // Close bracket.
848  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
849  }
850 
851  // Helper method to print run-time lvl/dim sizes.
852  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
853  unsigned size, bool isDim) {
854  // Open bracket.
855  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
856  // Print unrolled contents (dimop requires constant value).
857  for (unsigned i = 0; i < size; i++) {
858  auto idx = constantIndex(rewriter, loc, i);
859  Value val;
860  if (isDim)
861  val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
862  else
863  val = rewriter.create<LvlOp>(loc, tensor, idx);
864  rewriter.create<vector::PrintOp>(
865  loc, val,
866  i != size - 1 ? vector::PrintPunctuation::Comma
867  : vector::PrintPunctuation::NoPunctuation);
868  }
869  // Close bracket and end of line.
870  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
871  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
872  }
873 };
874 
/// Sparse rewriting rule for the sparse-to-sparse tensor.reshape operator.
876 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
877 public:
879 
880  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
881  PatternRewriter &rewriter) const override {
882  Location loc = op.getLoc();
883  Value srcTensor = op.getSource();
884  const auto srcTp = tryGetSparseTensorType(srcTensor);
885  const auto dstTp = tryGetSparseTensorType(op.getResult());
886  if (!srcTp || !dstTp)
887  return failure();
888 
889  if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890  !dstTp->hasStaticDimShape())
891  return failure();
892 
893  SmallVector<Value> srcSizes;
894  sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
895  SmallVector<Value> dstSizes;
896  for (Dimension d : dstTp->getDimShape())
897  dstSizes.push_back(constantIndex(rewriter, loc, d));
898 
899  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
900  // Only need an unordered COO buffer if input and output are not sorted
901  // in the same way.
902  Type bufferTp = getBufferType(
903  dstTp->withoutDimToLvl(),
904  !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
905  SmallVector<Value> dynSizes;
906  Value buffer = rewriter
907  .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
908  nnz, Attribute())
909  .getResult();
910 
911  // Convert src coordinates to dst coordinates by first collapsing it to 1D
912  // and then expand it to the match the rank of the destination tensor.
913  // Implemented as follows:
914  // foreach srcCoords %srcTensor
915  // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
916  // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
917  // insert expandedCoords, %buffer
918  //
919  // followed by an optional
920  // %t = sparse_tensor.cast %tmp
921  // depending on whether the input/output are sorted in the same way.
922  const auto encSrc = srcTp->getEncoding();
923  ForeachOp foreachOp = rewriter.create<ForeachOp>(
924  loc, srcTensor, buffer,
925  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
926  ValueRange reduc) {
927  const Dimension srcRank = srcTp->getDimRank();
928  SmallVector<Value> srcDcvs;
929  srcDcvs.reserve(srcRank);
930  for (Dimension d = 0; d < srcRank; d++) {
931  Level lvl = toLvl(encSrc, d);
932  srcDcvs.push_back(srcLcvs[lvl]);
933  }
934 
935  Value collapseSize = constantIndex(builder, loc, 1);
936  for (Dimension d = 0; d < srcRank; d++)
937  collapseSize =
938  builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
939  SmallVector<Value, 1> collapsedSizes = {collapseSize};
940 
941  ReassociationIndices collapseIdx;
942  for (Dimension i = 0; i < srcRank; i++)
943  collapseIdx.push_back(i);
944  SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
945  SmallVector<Value, 1> collapsedDcvs;
946  reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947  collapsedSizes, collapsedDcvs);
948 
949  ReassociationIndices expandIdx;
950  for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951  expandIdx.push_back(i);
952  SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
953  SmallVector<Value> dstDcvs;
954  reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955  dstSizes, dstDcvs);
956 
957  auto t =
958  builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
959  builder.create<sparse_tensor::YieldOp>(loc, t);
960  });
961 
962  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
963  if (bufferTp != *dstTp) {
964  auto dstRTT = dstTp->getRankedTensorType();
965  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
966  rewriter.create<DeallocTensorOp>(loc, t);
967  t = converted;
968  }
969  rewriter.replaceOp(op, t);
970  return success();
971  }
972 };
973 
/// Sparse rewriting rule for the sparse-to-sparse tensor.expand_shape and
/// tensor.collapse_shape operators.
975 template <typename ReshapeOp>
976 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
977 public:
979 
980  LogicalResult matchAndRewrite(ReshapeOp op,
981  PatternRewriter &rewriter) const override {
982  Location loc = op.getLoc();
983  Value srcTensor = op.getSrc();
984  const auto srcTp = getSparseTensorType(srcTensor);
985  const auto dstTp = getSparseTensorType(op.getResult());
986  if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987  return failure();
988 
989  // Generate code to represent the static dimension constants or compute
990  // the dynamic dimension values.
991  SmallVector<Value> srcSizes;
992  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
993  SmallVector<Value> dstSizes;
994  SmallVector<Value> dstDynSizes;
995  if (dstTp.hasStaticDimShape()) {
996  for (Dimension d : dstTp.getDimShape())
997  dstSizes.push_back(constantIndex(rewriter, loc, d));
998  } else {
999  ArrayRef<Size> dstShape = dstTp.getDimShape();
1000  genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
1001  op.getReassociationIndices());
1002  for (auto [idx, shape] : llvm::enumerate(dstShape)) {
1003  if (shape == ShapedType::kDynamic)
1004  dstDynSizes.push_back(dstSizes[idx]);
1005  }
1006  }
1007  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
1008  // Only need a unordered COO buffer if input and output are not sorted
1009  // in the same way.
1010  Type bufferTp = getBufferType(
1011  dstTp.withoutDimToLvl(),
1012  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1013 
1014  Value buffer =
1015  rewriter
1016  .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
1017  /*sizeHint=*/nnz, Attribute())
1018  .getResult();
1019 
1020  // Implement the sparse2sparse reshape as follows:
1021  // foreach srcCoords %srcTensor
1022  // insert reshapeCvs(srcCoords), %buffer
1023  //
1024  // followed by an optional
1025  // %t = sparse_tensor.cast %tmp
1026  // depending on whether the input/output are sorted in the same way.
1027  const auto encSrc = srcTp.getEncoding();
1028  ForeachOp foreachOp = rewriter.create<ForeachOp>(
1029  loc, srcTensor, buffer,
1030  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1031  ValueRange reduc) {
1032  const Dimension dimRank = srcTp.getDimRank();
1033  SmallVector<Value> srcDcvs;
1034  srcDcvs.reserve(dimRank);
1035  for (Dimension d = 0; d < dimRank; d++) {
1036  Level lvl = toLvl(encSrc, d);
1037  srcDcvs.push_back(srcLcvs[lvl]);
1038  }
1039  SmallVector<Value> dstDcvs;
1040  reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1041  srcDcvs, dstSizes, dstDcvs);
1042  auto t =
1043  builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1044  builder.create<sparse_tensor::YieldOp>(loc, t);
1045  });
1046 
1047  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
1048  if (bufferTp != dstTp) {
1049  auto dstRTT = dstTp.getRankedTensorType();
1050  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
1051  rewriter.create<DeallocTensorOp>(loc, t);
1052  t = converted;
1053  }
1054  rewriter.replaceOp(op, t);
1055  return success();
1056  }
1057 };
1058 
1059 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1060 /// operator.
1061 template <typename ReshapeOp>
1062 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1063 public:
1065 
1066  LogicalResult matchAndRewrite(ReshapeOp op,
1067  PatternRewriter &rewriter) const override {
1068  Location loc = op->getLoc();
1069  auto encDst = getSparseTensorEncoding(op.getResult().getType());
1070  auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1071  // Since a pure dense expansion is very cheap (change of view), for
1072  // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1073  // conversion from the reshape operation itself.
1074  // All other cases are handled elsewhere.
1075  if (encDst && encSrc) {
1076  return failure();
1077  }
1078  if (encSrc) {
1079  auto rtp = getRankedTensorType(op.getSrc());
1080  auto denseTp =
1081  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1082  auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
1083  rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1084  return success();
1085  }
1086  if (encDst) {
1087  auto rtp = getRankedTensorType(op.getResult());
1088  auto denseTp =
1089  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1090  ReshapeOp reshape;
1091  if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1092  reshape = rewriter.create<ReshapeOp>(
1093  loc, denseTp, op.getSrc(), op.getReassociation(),
1094  op.getOutputShape(), op.getStaticOutputShape());
1095  } else {
1096  reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
1097  op.getReassociation());
1098  }
1099  Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
1100  rewriter.replaceOp(op, convert);
1101  return success();
1102  }
1103  return failure();
1104  }
1105 };
1106 
1107 // A trivial wrapper to help generate different operations for dense/sparse
1108 // tensors.
1109 struct TensorLike {
1110  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1111  ValueRange sizes) {
1112  SmallVector<Value> dynSzs;
1113  getDynamicSizes(rtt, sizes, dynSzs);
1114 
1115  val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
1116  if (!isSparse()) {
1117  Value c0 = constantZero(builder, loc, rtt.getElementType());
1118  val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
1119  }
1120  }
1121 
1122  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1123  val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1124  }
1125 
1126  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1127  if (isSparse())
1128  return builder.create<LoadOp>(loc, val, true);
1129  return val;
1130  }
1131 
1132  bool isSparse() const {
1133  return getSparseTensorEncoding(val.getType()) != nullptr;
1134  }
1135 
1136  Value val;
1137 };
1138 
1139 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1141  LogicalResult matchAndRewrite(tensor::DimOp op,
1142  PatternRewriter &rewriter) const override {
1143  std::optional<int64_t> dim = op.getConstantIndex();
1144  auto stt = tryGetSparseTensorType(op.getSource());
1145  if (!dim || !stt || !stt->hasEncoding())
1146  return failure();
1147 
1148  if (stt->isPermutation()) {
1149  rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1150  toLvl(stt->getEncoding(), *dim));
1151  return success();
1152  }
1153 
1154  // Non-permutation dim2lvl/lvl2dim maps.
1155  // Compute as follows:
1156  // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1157  // Note that it is not the most efficient way (but a more general one) for
1158  // the lvl to dim translation, e.g., for BSR, the dimension size for can be
1159  // computed simply by lvl_size * block_size.
1160  Location loc = op.getLoc();
1161  SmallVector<Value> maxLvlCrds;
1162  for (Level l = 0; l < stt->getLvlRank(); l++) {
1163  Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1164  Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1165  loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
1166  maxLvlCrds.push_back(maxLvlCrd);
1167  }
1168 
1169  AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170  Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1171  op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172  maxLvlCrds);
1173 
1174  Value dimSz = rewriter.create<arith::AddIOp>(
1175  loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1176  rewriter.replaceOp(op, dimSz);
1177  return success();
1178  }
1179 };
1180 
1181 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1183  LogicalResult matchAndRewrite(ConcatenateOp op,
1184  PatternRewriter &rewriter) const override {
1185  if (op.needsExtraSort())
1186  op.emitError("ConcatenateOp not staged");
1187 
1188  const Location loc = op.getLoc();
1189  const auto dstTp = getSparseTensorType(op);
1190  const Dimension conDim = op.getDimension();
1191  SmallVector<Value> sizes;
1192  concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1193 
1194  // %t = concatenate %s1, %s2, %s3 {dim = 1}
1195  // ==>
1196  // if (isSparseDst)
1197  // if (allDense)
1198  // %tmp = bufferization.alloc_tensor dstTp
1199  // else
1200  // %tmp = bufferization.alloc_tensor : unordered COO
1201  // else
1202  // %tmp = memref.alloc : dense tensor
1203  // foreach in %s1 : insert d0, d1, %tmp
1204  // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1205  // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1206 
1207  TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1208  Value offset = constantIndex(rewriter, loc, 0);
1209  Value iterArg = dstBuf.val;
1210 
1211  ForeachOp foreachOp;
1212  for (Value input : op.getInputs()) {
1213  // Builds a for op for each input tensor to append new values into the
1214  // output tensor.
1215  foreachOp = rewriter.create<ForeachOp>(
1216  loc, input, iterArg,
1217  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1218  ValueRange reduc) {
1219  SmallVector<Value> offDimCrd(dcvs);
1220  offDimCrd[conDim] =
1221  builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1222 
1223  // Enters foreach, updates the SSA chain.
1224  dstBuf.val = reduc.front();
1225  if (!dstTp.isAllDense()) {
1226  Value cond = genIsNonzero(builder, loc, v);
1227  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1228  /*else*/ true);
1229  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1230  builder.create<scf::YieldOp>(loc, dstBuf.val);
1231 
1232  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1233  dstBuf.insert(builder, loc, v, offDimCrd);
1234  builder.create<scf::YieldOp>(loc, dstBuf.val);
1235 
1236  // Exits the ifOp, update the sparse tensor SSA value.
1237  builder.setInsertionPointAfter(ifOp);
1238  dstBuf.val = ifOp.getResult(0);
1239  } else {
1240  dstBuf.insert(builder, loc, v, offDimCrd);
1241  }
1242  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1243  });
1244  // Accumulates the offset. Note that only static-shaped inputs are allowed
1245  // by concatenate op verifier, which saves us from computing the offset
1246  // dynamically.
1247  const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1248  assert(!ShapedType::isDynamic(sz));
1249  offset = rewriter.create<arith::AddIOp>(loc, offset,
1250  constantIndex(rewriter, loc, sz));
1251  iterArg = foreachOp.getResult(0);
1252  dstBuf.val = iterArg;
1253  }
1254 
1255  dstBuf.val = iterArg;
1256  Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1257  rewriter.replaceOp(op, ret);
1258  return success();
1259  }
1260 };
1261 
1262 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1264  LogicalResult matchAndRewrite(ConvertOp op,
1265  PatternRewriter &rewriter) const override {
1266  if (op.needsExtraSort())
1267  return op.emitError("ConvertOp not staged.");
1268 
1269  // TODO: Maybe we want a different operation for this too.
1270  auto encDst = getSparseTensorEncoding(op.getType());
1271  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1272  if (encDst && encSrc && !encSrc.isSlice() &&
1273  encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1274  // Trivial tensor conversion and simple element type conversion is handled
1275  // in codegen.
1276  return failure();
1277  }
1278 
1279  Location loc = op.getLoc();
1280  Value src = op.getSource();
1281 
1282  SparseTensorType srcStt = getSparseTensorType(op.getSource());
1283  SparseTensorType dstStt = getSparseTensorType(op.getDest());
1284 
1285  bool fromSparseConst = false;
1286  if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1287  if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1288  fromSparseConst = true;
1289 
1290  const AffineMapAttr foreachOrder =
1291  (!dstStt.isIdentity() && fromSparseConst)
1293  : nullptr;
1294 
1295  bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1296 
1297  SmallVector<Value> sizes;
1298  sizesFromSrc(rewriter, sizes, loc, src);
1299  ValueRange vs;
1300  TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1301 
1302  auto foreachOp = rewriter.create<ForeachOp>(
1303  loc, src, dstBuf.val, foreachOrder,
1304  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1305  ValueRange reduc) {
1306  // Enters the loop, update the SSA value for insertion chain.
1307  dstBuf.val = reduc.front();
1308  if (!skipZeroCheck) {
1309  Value cond = genIsNonzero(builder, loc, v);
1310  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1311  /*else*/ true);
1312  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1313  builder.create<scf::YieldOp>(loc, dstBuf.val);
1314 
1315  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1316  dstBuf.insert(builder, loc, v, dcvs);
1317  builder.create<scf::YieldOp>(loc, dstBuf.val);
1318 
1319  // Exits the ifOp, update the sparse tensor SSA value.
1320  builder.setInsertionPointAfter(ifOp);
1321  dstBuf.val = ifOp.getResult(0);
1322  } else {
1323  dstBuf.insert(builder, loc, v, dcvs);
1324  }
1325  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1326  });
1327 
1328  rewriter.setInsertionPointAfter(foreachOp);
1329 
1330  // Exits the for loop, links the SSA chain.
1331  dstBuf.val = foreachOp.getResult(0);
1332 
1333  Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1334  rewriter.replaceOp(op, ret);
1335  return success();
1336  }
1337 };
1338 
1339 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1341  LogicalResult matchAndRewrite(CrdTranslateOp op,
1342  PatternRewriter &rewriter) const override {
1343  AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1344  ? op.getEncoder().getDimToLvl()
1345  : op.getEncoder().getLvlToDim();
1346 
1347  SmallVector<Value> outCrds;
1348  for (AffineExpr result : map.getResults()) {
1349  // TODO: we should probably expand the affine map to IR using our own
1350  // rules, since affine.apply assume signed value, while the cooridinates
1351  // we provided must always be signless.
1352  Value trans = rewriter.create<affine::AffineApplyOp>(
1353  op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1354  op.getInCrds());
1355  outCrds.push_back(trans);
1356  }
1357  rewriter.replaceOp(op, outCrds);
1358  return success();
1359  }
1360 };
1361 
1362 /// Sparse rewriting rule for the foreach operator.
1363 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1364 public:
1366 
1367  LogicalResult matchAndRewrite(ForeachOp op,
1368  PatternRewriter &rewriter) const override {
1369 
1370  auto loc = op.getLoc();
1371  Value input = op.getTensor();
1372  SmallVector<Value> reduc = op.getInitArgs();
1373  const auto stt = getSparseTensorType(input);
1374  const Level lvlRank = stt.getLvlRank();
1375 
1376  // Special-case: for each over a sparse constant uses its own rewriting
1377  // rule.
1378  if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1379  if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1380  return genForeachOnSparseConstant(op, rewriter, attr);
1381  }
1382  }
1383 
1384  // Otherwise, use loop emitter to generate loops.
1385  const auto enc = stt.getEncoding();
1386 
1387  // 1. Generates loop for the sparse input.
1388  LoopEmitter loopEmitter(
1389  ValueRange{input},
1390  StringAttr::get(getContext(), ForeachOp::getOperationName()));
1391  loopEmitter.initializeLoopEmit(rewriter, loc);
1392  for (Level l = 0; l < lvlRank; l++) {
1393  // TODO: provide utility function for loop sequences that only contains
1394  // one for loop?
1395  const SmallVector<TensorLevel, 1> tidLvls{
1396  loopEmitter.makeTensorLevel(0, l)};
1397  loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1398  // Note that reduc will be taken care of by loop emitter and get updated
1399  // in place.
1400  loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1401  reduc);
1402  }
1403 
1404  SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1405  if (op.getOrder()) {
1406  // TODO: Support it so that we can do direct conversion from CSR->BSR.
1407  llvm_unreachable(
1408  "Level order not yet implemented on non-constant input tensors.");
1409  }
1410 
1411  Value vals = loopEmitter.getValBuffer()[0];
1412  SmallVector<Value> pos = loopEmitter.getValPosits(0);
1413  // Loads the value from sparse tensor using position-index;
1414  // loads the value from dense tensor using coords.
1415  Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1416  : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1417 
1418  // 2. Inline the block in the foreach operator.
1419  Block *srcBlock = op.getBody();
1420 
1421  // Remap coordinates.
1422  SmallVector<Value> args =
1423  enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1424 
1425  // Remap value.
1426  args.push_back(val);
1427  // Remap reduction variables.
1428  args.append(reduc);
1429 
1430  // Remove sparse_tensor.yield.
1431  SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1432  rewriter.eraseOp(srcBlock->getTerminator());
1433 
1434  Operation &last = rewriter.getBlock()->back();
1435  if (llvm::isa<scf::YieldOp>(last)) {
1436  // Because `scf.for` inserts an implicit yield op when there is no
1437  // reduction variable upon creation, we reset the insertion point such
1438  // that the block is inlined before *before* the yield op.
1439  rewriter.setInsertionPoint(&last);
1440  }
1441 
1442  rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1443  rewriter.getInsertionPoint(), args);
1444  rewriter.setInsertionPointToEnd(rewriter.getBlock());
1445  for (Level l = 0; l < lvlRank; l++) {
1446  // Link the reduction chain. Note that loop emitter update the reducValue
1447  // in place.
1448  loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1449  loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1450  }
1451 
1452  // Replace the foreach operator with the value returned by the outtermost
1453  // for loop.
1454  rewriter.replaceOp(op, reducValue);
1455  return success();
1456  }
1457 };
1458 
1459 /// Sparse rewriting rule for the new operator.
1460 struct NewRewriter : public OpRewritePattern<NewOp> {
1462  LogicalResult matchAndRewrite(NewOp op,
1463  PatternRewriter &rewriter) const override {
1464  Location loc = op.getLoc();
1465  auto stt = getSparseTensorType(op.getResult());
1466  if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1467  return failure();
1468 
1469  // Implement the NewOp as follows:
1470  // %orderedCoo = sparse_tensor.new %filename
1471  // %t = sparse_tensor.convert %orderedCoo
1472  // with enveloping reinterpreted_map ops for non-permutations.
1473  RankedTensorType dstTp = stt.getRankedTensorType();
1474  RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1475  Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1476  Value convert = cooTensor;
1477  auto enc = stt.getEncoding();
1478  if (!stt.isPermutation()) { // demap coo, demap dstTp
1479  auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1480  convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1481  dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1482  }
1483  convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1484  if (!stt.isPermutation()) // remap to original enc
1485  convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1486  rewriter.replaceOp(op, convert);
1487 
1488  // Release the temporary ordered COO tensor.
1489  rewriter.setInsertionPointAfterValue(convert);
1490  rewriter.create<DeallocTensorOp>(loc, cooTensor);
1491 
1492  return success();
1493  }
1494 };
1495 
1496 /// Sparse rewriting rule for the out operator.
1497 struct OutRewriter : public OpRewritePattern<OutOp> {
1499  LogicalResult matchAndRewrite(OutOp op,
1500  PatternRewriter &rewriter) const override {
1501  Location loc = op.getLoc();
1502  // Calculate NNZ.
1503  Value src = op.getTensor();
1504  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1505 
1506  // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1507  const auto srcTp = getSparseTensorType(src);
1508  const Dimension dimRank = srcTp.getDimRank();
1509  Type indexTp = rewriter.getIndexType();
1510  Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1511 
1512  // Generate code to calculate dimension size values and store the values to
1513  // the buffer.
1514  SmallVector<Value> dims;
1515  sizesForTensor(rewriter, dims, loc, srcTp, src);
1516  for (Dimension d = 0; d < dimRank; d++) {
1517  rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1518  constantIndex(rewriter, loc, d));
1519  }
1520 
1521  // Create a sparse tensor writer and output meta data.
1522  Type opaqueTp = getOpaquePointerType(rewriter);
1523  Value writer =
1524  createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1525  {op.getDest()}, EmitCInterface::Off)
1526  .getResult(0);
1527  Value rankValue = constantIndex(rewriter, loc, dimRank);
1528  createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1529  {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1530 
1531  Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1532  Type eltTp = srcTp.getElementType();
1533  SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1534  primaryTypeFunctionSuffix(eltTp)};
1535  Value value = genAllocaScalar(rewriter, loc, eltTp);
1536  ModuleOp module = op->getParentOfType<ModuleOp>();
1537 
1538  // For each element in the source tensor, output the element.
1539  rewriter.create<ForeachOp>(
1540  loc, src, std::nullopt,
1541  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1542  ValueRange reduc) {
1543  for (Dimension d = 0; d < dimRank; d++) {
1544  rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1545  constantIndex(builder, loc, d));
1546  }
1547  rewriter.create<memref::StoreOp>(loc, v, value);
1548  SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1549  FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1550  EmitCInterface::On);
1551  builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1552  builder.create<sparse_tensor::YieldOp>(loc);
1553  });
1554 
1555  // Release the writer.
1556  createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1557  EmitCInterface::Off);
1558 
1559  rewriter.eraseOp(op);
1560  return success();
1561  }
1562 };
1563 
1564 } // namespace
1565 
1566 //===---------------------------------------------------------------------===//
1567 // Methods that add patterns described in this file to a pattern list.
1568 //===---------------------------------------------------------------------===//
1569 
1571  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1572  FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1573  GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1574  patterns.getContext());
1575 }
1576 
1578  bool enableRT,
1579  bool enableConvert) {
1580  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1581  ReshapeRewriter<tensor::CollapseShapeOp>,
1582  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1583  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1584  SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1585  patterns.getContext());
1586 
1587  if (enableConvert)
1588  patterns.add<DirectConvertRewriter>(patterns.getContext());
1589  if (!enableRT)
1590  patterns.add<NewRewriter>(patterns.getContext());
1591 }
1592 
1594  // Run CrdTranslateRewriter later in the pipeline so that operation can be
1595  // folded before lowering to affine.apply
1596  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1597 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static bool isMulChain(Value val, Value x)
static bool isSampling(GenericOp op)
static bool isSumOfMul(GenericOp op)
static bool isZeroValue(Value val)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static LogicalResult genForeachOnSparseConstant(ForeachOp op, RewriterBase &rewriter, SparseElementsAttr attr)
static bool isMaterializing(OpOperand *op, bool isZero)
static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType dstTp, ValueRange srcs, unsigned dim)
Populates the given sizes array for concatenation from types (for static sizes) and from the source t...
static bool isSparseTensor(Value v)
static bool isZeroYield(GenericOp op)
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType stp, Value tensor)
Populates given sizes array from type (for static sizes) and from the tensor (for dynamic sizes).
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:293
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:355
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:395
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:349
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:579
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:461
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:456
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:140
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
RankedTensorType getRankedTensorType() const
Explicitly convert to RankedTensorType.
AffineMap getExpandedDimToLvl() const
Returns the dimToLvl mapping, where the identity map is expanded out into a full AffineMap.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1239
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:90
auto m_Any()
Definition: Matchers.h:532
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Returns a function reference (first hit also inserts into module).
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
void foreachInSparseConstant(OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, function_ref< void(ArrayRef< Value >, Value)> callback)
Iterate over a sparse constant, generates constantOp for value and coordinates.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:312
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Definition: CodegenUtils.h:323
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:39
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
Definition: SparseTensor.h:46
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:160
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
void genReshapeDstShape(OpBuilder &builder, Location loc, SmallVectorImpl< Value > &dstShape, ArrayRef< Value > srcShape, ArrayRef< Size > staticDstShape, ArrayRef< ReassociationIndices > reassociation)
Computes the shape of destination tensor of a reshape operator.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
void reshapeCvs(OpBuilder &builder, Location loc, ArrayRef< ReassociationIndices > reassociation, ValueRange srcSizes, ValueRange srcCvs, ValueRange dstSizes, SmallVectorImpl< Value > &dstCvs)
Reshape coordinates during a reshaping operation.
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
Definition: SparseTensor.h:186
SparseTensorFieldKind
The sparse tensor storage...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, Value src)
Populates given sizes array from dense tensor or sparse tensor constant.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:123
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:437
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:399
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238