1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenUtils.h"
14 #include "Utils/LoopEmitter.h"
15 
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/Support/LLVM.h"
32 
33 using namespace mlir;
34 using namespace mlir::bufferization;
35 using namespace mlir::linalg;
36 using namespace mlir::sparse_tensor;
37 
38 //===---------------------------------------------------------------------===//
39 // Helper methods for the actual rewriting rules.
40 //===---------------------------------------------------------------------===//
41 
42 // Helper method to match any typed zero.
43 static bool isZeroValue(Value val) {
44  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45 }
46 
47 // Helper to detect a sparse tensor type operand.
48 static bool isSparseTensor(Value v) {
49  auto enc = getSparseTensorEncoding(v.getType());
50  return enc && !llvm::all_of(enc.getLvlTypes(),
51  [](auto lt) { return lt == LevelFormat::Dense; });
52 }
53 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54 
55 // Helper method to find zero/uninitialized tensor materialization.
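// E.g., a bufferization.alloc_tensor without a copy operand (or a
// tensor.empty) counts as an uninitialized materialization, while an
// alloc_tensor whose copy operand is a zero value counts as a zero
// materialization (illustrative summary of the checks below).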
56 static bool isMaterializing(OpOperand *op, bool isZero) {
57  Value val = op->get();
58  // Check allocation, with zero alloc when required.
59  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60  Value copy = alloc.getCopy();
61  if (isZero)
62  return copy && isZeroValue(copy);
63  return !copy;
64  }
65  // Check for empty tensor materialization.
66  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67  return !isZero;
68  // Last resort for zero alloc: the whole value is zero.
69  return isZero && isZeroValue(val);
70 }
71 
72 // Helper to detect sampling operation.
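// That is, the region multiplies its two scalar block arguments and yields
// the product, e.g. (sketch):
//   %0 = arith.mulf %s1, %s2 : f64
//   linalg.yield %0 : f64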
73 static bool isSampling(GenericOp op) {
74  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77  // Both scalar input arguments used exactly once.
78  Value s1 = op.getBlock()->getArgument(0);
79  Value s2 = op.getBlock()->getArgument(1);
80  return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81  (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82  }
83  }
84  return false;
85 }
86 
87 // Helper to detect chain of multiplications that do not involve x.
88 static bool isMulChain(Value val, Value x) {
89  if (auto arg = dyn_cast<BlockArgument>(val))
90  return arg != x;
91  if (auto *def = val.getDefiningOp()) {
92  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93  return isMulChain(def->getOperand(0), x) &&
94  isMulChain(def->getOperand(1), x);
95  }
96  return false;
97 }
98 
99 // Helper to detect x = x + <multiplications>.
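// E.g. (sketch), with %x the output block argument:
//   %m = arith.mulf %a, %b : f64
//   %s = arith.addf %x, %m : f64
//   linalg.yield %s : f64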
100 static bool isSumOfMul(GenericOp op) {
101  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104  Value x = op.getBlock()->getArguments().back();
105  return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106  (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107  }
108  }
109  return false;
110 }
111 
112 // Helper to detect direct yield of a zero value.
113 static bool isZeroYield(GenericOp op) {
114  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116  if (arg.getOwner()->getParentOp() == op) {
117  return isZeroValue(op->getOperand(arg.getArgNumber()));
118  }
119  }
120  return isZeroValue(yieldOp.getOperand(0));
121 }
122 
123 /// Populates the given sizes array from the type (for static sizes) and from
124 /// the tensor (for dynamic sizes).
125 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126  Location loc, ShapedType stp, Value tensor) {
127  for (const auto &d : enumerate(stp.getShape())) {
128  Value dim;
129  if (d.value() == ShapedType::kDynamic)
130  dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131  else
132  dim = constantIndex(builder, loc, d.value());
133  sizes.push_back(dim);
134  }
135 }
136 
137 static RankedTensorType getBufferType(const SparseTensorType &stt,
138  bool needTmpCOO) {
139  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140  : stt.getRankedTensorType();
141 }
142 
143 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145 /// sizes to dynSizes.
146 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147  SmallVectorImpl<Value> &dynSizes) {
148  for (const auto &d : enumerate(tp.getShape())) {
149  if (d.value() == ShapedType::kDynamic)
150  dynSizes.push_back(sizes[d.index()]);
151  }
152 }
153 
154 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155  RewriterBase &rewriter,
156  SparseElementsAttr attr) {
157  auto loc = op.getLoc();
158  SmallVector<Value> reduc = op.getInitArgs();
159 
160  // Foreach on constant.
161  foreachInSparseConstant(
162  rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163  [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164  SmallVector<Value> args;
165  args.append(cvs.begin(), cvs.end());
166  args.push_back(v);
167  args.append(reduc);
168  // Clones the foreach op to get a copy of the loop body.
169  auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170  assert(args.size() == cloned.getBody()->getNumArguments());
171  Operation *yield = cloned.getBody()->getTerminator();
172  rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173  // clean up
174  rewriter.eraseOp(cloned);
175  reduc = yield->getOperands();
176  rewriter.eraseOp(yield);
177  });
178 
179  rewriter.replaceOp(op, reduc);
180  return success();
181 }
182 
183 /// Populates the given sizes array for concatenation from types (for static
184 /// sizes) and from the source tensors (for dynamic sizes).
185 static void concatSizesFromInputs(OpBuilder &builder,
186  SmallVectorImpl<Value> &sizes, Location loc,
187  ShapedType dstTp, ValueRange srcs,
188  unsigned dim) {
189  auto dstShape = dstTp.getShape();
190  sizesFromSrc(builder, sizes, loc, srcs[0]);
191 
192  // Sum up on the `dim` if the dimension is dynamic.
193  if (dstShape[dim] != ShapedType::kDynamic) {
194  // Faithfully take the static size.
195  sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196  } else {
197  // Else, compute the shape dynamically.
198  for (const auto &src : srcs.drop_front()) {
199  Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200  // Sum up all the sizes.
201  sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202  }
203  }
204 }
205 
206 //===---------------------------------------------------------------------===//
207 // The actual sparse tensor rewriting rules.
208 //===---------------------------------------------------------------------===//
209 
210 namespace {
211 
212 /// TODO: move it to tensor dialect instead.
213 ///
214 /// Fold `tensor.concat` and `tensor.extract_slice`
215 ///
216 /// %concat = tensor.concat dim(2) %t0, %t1
217 /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218 /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220 /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222 ///
223 /// Becomes
224 ///
225 /// %extract0, %extract1 = %t0, %t1
226 struct FuseExtractSliceWithConcat
227  : public OpRewritePattern<tensor::ExtractSliceOp> {
228  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231  PatternRewriter &rewriter) const override {
232  auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233  if (!concatOp)
234  return failure();
235 
236  Location loc = extractOp.getLoc();
237  int64_t dim = concatOp.getDim();
238  int64_t rank = extractOp.getResultType().getRank();
239 
240  SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241  SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242 
243  // Compute the partial sums for the slice offsets.
244  AffineExpr sum = rewriter.getAffineDimExpr(0);
245  SmallVector<AffineExpr> partialSums = {sum};
246  SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247  for (auto [idx, input] :
248  llvm::enumerate(concatOp.getInputs().drop_back())) {
249  sum = sum + rewriter.getAffineDimExpr(idx + 1);
250  partialSums.push_back(sum);
251  offsetStrides.push_back(
252  rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253  }
254  auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255  partialSums, rewriter.getContext());
256  SmallVector<OpFoldResult> dimOffsets =
257  affine::makeComposedFoldedMultiResultAffineApply(
258  rewriter, loc, partialSumMap, offsetStrides);
259 
260  auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261  for (auto [l, r] : llvm::zip(lhs, rhs)) {
262  std::optional<int64_t> staticVal = getConstantIntValue(l);
263  if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264  return false;
265  }
266  return lhs.size() == rhs.size();
267  };
268 
269  for (auto [i, input, offset] :
270  llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271  SmallVector<OpFoldResult> srcSizes =
272  tensor::getMixedSizes(rewriter, loc, input);
273  srcOffsets[dim] = offset;
274 
275  SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276  SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277  SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278 
279  if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280  allEqual(srcStrides, dstStrides)) {
281  Value operand = concatOp.getOperand(i);
282  if (operand.getType() == extractOp.getResultType())
283  rewriter.replaceOp(extractOp, operand);
284  break;
285  }
286  }
287 
288  return success();
289  }
290 };
291 
292 /// Rewriting rule that fuses sparse_tensor.convert into producer.
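/// As an illustrative sketch (types abbreviated, not taken from a test case):
///
///   %0 = tensor.empty() : tensor<4x4xf32>
///   %1 = linalg.generic ... outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
///   %2 = sparse_tensor.convert %1 : tensor<4x4xf32> to tensor<4x4xf32, #CSR>
///
/// is rewritten so that the materialization and the generic op produce the
/// sparse result type directly, and the convert disappears.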
293 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294 public:
295  using OpRewritePattern<ConvertOp>::OpRewritePattern;
296 
297  LogicalResult matchAndRewrite(ConvertOp op,
298  PatternRewriter &rewriter) const override {
299  auto producer = op.getSource().getDefiningOp<GenericOp>();
300  if (!producer || producer.getDpsInits().size() != 1 ||
301  !isMaterializing(producer.getDpsInitOperand(0), false) ||
302  !producer.getResult(0).hasOneUse()) {
303  return failure();
304  }
305  // Clone the materialization operation, but update the result to sparse.
306  rewriter.setInsertionPoint(producer);
307  Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
308  Operation *cloned = rewriter.clone(*init);
309  cloned->getResult(0).setType(op.getResult().getType());
310 
311  rewriter.modifyOpInPlace(producer, [&]() {
312  producer.getDpsInitsMutable().assign(cloned->getResults());
313  producer.getResult(0).setType(op.getResult().getType());
314  });
315 
316  rewriter.replaceAllOpUsesWith(op, producer);
317  op->erase();
318 
319  return success();
320  }
321 };
322 
323 /// Rewriting rule that folds a direct yield of zero into the initial allocation (sparse output) or a constant zero (static dense output).
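/// As a rough sketch: a linalg.generic whose region is just
///   linalg.yield %cst_zero
/// over a freshly materialized sparse output folds to that (all-zero)
/// allocation, since a new sparse tensor stores no entries anyway.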
324 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325 public:
326  using OpRewritePattern<GenericOp>::OpRewritePattern;
327 
328  LogicalResult matchAndRewrite(GenericOp op,
329  PatternRewriter &rewriter) const override {
330  if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331  !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
332  !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333  return failure();
334  auto outputType = getRankedTensorType(op.getResult(0));
335  // Yielding zero on newly materialized sparse tensor can be
336  // optimized directly (regardless of dynamic or static size).
337  if (getSparseTensorEncoding(outputType)) {
338  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339  return success();
340  }
341  // Use static zero value directly instead of materialization.
342  if (!outputType.hasStaticShape())
343  return failure();
344  Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
345  rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
346  rewriter.eraseOp(def);
347  return success();
348  }
349 };
350 
351 /// Rewriting rule that converts two kernels:
352 ///
353 /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
354 /// X(i,j) = S(i,j) * T(i,j)
355 ///
356 /// into a single kernel, using distributive law:
357 ///
358 /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
359 ///
360 /// This kind of fusion (merging two ops into one but using arithmetic
361 /// equalities that may not hold for floating-point computations) would
362 /// be undesirable in the dense case, since we distribute the multiplication
363 /// into the reduction loop. However, for sparse sampling tensor S, such
364 /// a fusion may actually reduce the asymptotic complexity of the kernel,
365 /// since intermediate results may be nullified.
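/// A well-known instance is sampled dense-dense matrix multiplication (SDDMM),
/// where a sparse sampling matrix S(i,j) scales the result of a dense matrix
/// product; after fusion the kernel can skip all (i,j) positions at which S
/// is zero.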
366 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
367 public:
368  using OpRewritePattern<GenericOp>::OpRewritePattern;
369 
370  LogicalResult matchAndRewrite(GenericOp op,
371  PatternRewriter &rewriter) const override {
372  // Check consumer.
373  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374  op.getNumResults() != 1 ||
375  op.getNumParallelLoops() != op.getNumLoops() ||
376  !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377  !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378  !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379  return failure();
380  // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
381  // operand can be sparse or dense, since the point of this rewriting rule
382  // is detecting a situation in which *more* sparsity is introduced into
383  // a computation, be it already sparse or still dense.
384  unsigned other = 0;
385  if (isSparseTensor(op.getDpsInputOperand(0)))
386  other = 1;
387  else if (!isSparseTensor(op.getDpsInputOperand(1)))
388  return failure();
389  // Check producer.
390  auto prod = dyn_cast_or_null<GenericOp>(
391  op.getDpsInputOperand(other)->get().getDefiningOp());
392  if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393  !prod.getResult(0).hasOneUse())
394  return failure();
395  // Sampling consumer and sum of multiplication chain producer.
396  if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
397  !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398  !isSampling(op) || !isSumOfMul(prod))
399  return failure();
400  // Modify operand structure of producer and consumer.
401  Location loc = prod.getLoc();
402  SmallVector<Value> inputOps = prod.getInputs();
403  SmallVector<Value> outputOps = op.getOutputs();
404  SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
405  inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406  fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
407  // Fuse producer and consumer into a new generic op.
408  auto fusedOp = rewriter.create<GenericOp>(
409  loc, op.getResult(0).getType(), inputOps, outputOps,
410  rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
411  /*doc=*/nullptr, /*library_call=*/nullptr);
412  Block &prodBlock = prod.getRegion().front();
413  Block &consBlock = op.getRegion().front();
414  IRMapping mapper;
415  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
416  unsigned num = prodBlock.getNumArguments();
417  for (unsigned i = 0; i < num - 1; i++)
418  addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419  addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
420  addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421  // Clone bodies of the producer and consumer in new evaluation order.
422  auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
423  auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
424  Value last;
425  for (auto &op : prodBlock.without_terminator())
426  if (&op != acc) {
427  last = op.getResult(0);
428  rewriter.clone(op, mapper);
429  }
430  mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
431  mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
432  last = rewriter.clone(*acc, mapper)->getResult(0);
433  rewriter.create<linalg::YieldOp>(loc, last);
434  // Force initial value on merged allocation for dense outputs.
435  // TODO: deal with non alloc tensor here one day
436  if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437  Value init = prod.getDpsInitOperand(0)
438  ->get()
439  .getDefiningOp<AllocTensorOp>()
440  .getCopy();
441  AllocTensorOp a =
442  op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443  rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444  }
445  // Replace consumer with fused operation. Old producer
446  // and consumer ops will be removed by DCE.
447  rewriter.replaceOp(op, fusedOp->getResults());
448  return success();
449  }
450 
451 private:
452  // Helper to add argument and record the mapping.
453  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
454  mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
455  }
456 };
457 
458 // Fuse a tensor cast into producing operation. Note that a tensor.cast
459 // should really not be used to convert between sparse encodings. Since
460 // the pattern currently appears as a result of some prior rewriting
461 // we make an attempt to repair very obvious cases.
462 // TODO: audit the pure tensor dialect rewriting rules
463 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
464 public:
465  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
466 
467  LogicalResult matchAndRewrite(tensor::CastOp op,
468  PatternRewriter &rewriter) const override {
469  Type srcType = op.getSource().getType();
470  Type dstType = op.getDest().getType();
471  // A nop cast simply folds away.
472  if (srcType == dstType) {
473  rewriter.replaceOp(op, op->getResults());
474  return success();
475  }
476  // See if a sparsity changing cast can be fused into producer.
477  if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
478  if (Operation *def = op.getSource().getDefiningOp()) {
479  if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
480  rewriter.modifyOpInPlace(def, [&]() {
481  def->getResult(0).setType(op->getResultTypes()[0]);
482  });
483  rewriter.replaceOp(op, def->getResult(0));
484  return success();
485  }
486  }
487  }
488  // Repair tensor casts with at least one sparse operand into the
489  // properly supported sparse_tensor.convert.
490  if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
491  rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
492  return success();
493  }
494  // Fail otherwise.
495  return failure();
496  }
497 };
498 
499 /// Rewrites a sequence of operations for sparse tensor selections into
500 /// semi-ring operations such that they can be compiled correctly by the
501 /// sparsifier. E.g., transforming the following sequence
502 ///
503 /// %sel = arith.select %cond, %sp1, %sp2
504 ///
505 /// to
506 ///
507 /// %sel = binary %sp1, %sp2:
508 /// both (%l, %r) {yield select %cond, %l, %r}
509 /// left (%l) {yield select %cond, %l, 0}
510 /// right (%r) {yield select %cond, 0, %r}
511 ///
512 /// TODO: We require the tensor used for extracting the conditions to be dense
513 /// in order to sparsify the code. To support a sparse condition tensor, we
514 /// need a ternary operation.
515 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516 public:
517  using OpRewritePattern<GenericOp>::OpRewritePattern;
518  LogicalResult matchAndRewrite(GenericOp op,
519  PatternRewriter &rewriter) const override {
520  // Rejects non sparse kernels.
521  if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522  return failure();
523 
524  Location loc = op.getLoc();
525  SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
526  for (Operation &inst : *op.getBody()) {
527  // Matches pattern.
528  auto matched = isRewritablePattern(op, &inst);
529  if (!matched.has_value())
530  continue;
531 
532  rewriter.setInsertionPoint(&inst);
533  auto [c, t, f] = matched.value();
534  assert(t.getType() == f.getType());
535  auto selTp = t.getType();
536  auto c0 = constantZero(rewriter, loc, selTp);
537  auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
538  // Initializes all the blocks.
539  rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540  {t.getLoc(), f.getLoc()});
541  rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542  rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543 
544  for (auto *r : binOp.getRegions()) {
545  Block *b = &r->front();
546  rewriter.setInsertionPointToStart(b);
547 
548  IRMapping irMap;
549  // Clones the cmp operations into the region to make the binary op
550  // admissible.
551  Value newC = c;
552  if (auto *def = c.getDefiningOp())
553  newC = rewriter.clone(*def, irMap)->getResult(0);
554 
555  irMap.map(c, newC);
556  if (r == &binOp.getLeftRegion()) {
557  irMap.map(t, b->getArgument(0));
558  irMap.map(f, c0);
559  } else if (r == &binOp.getRightRegion()) {
560  irMap.map(t, c0);
561  irMap.map(f, b->getArgument(0));
562  } else {
563  irMap.map(t, b->getArgument(0));
564  irMap.map(f, b->getArgument(1));
565  }
566  auto y = rewriter.clone(inst, irMap)->getResult(0);
567  rewriter.create<sparse_tensor::YieldOp>(loc, y);
568  }
569 
570  // We successfully rewrote an operation. We cannot do the replacement here
571  // because it would invalidate the iterator of the current loop over the
572  // instructions.
573  semiRings.emplace_back(&inst, binOp);
574  }
575 
576  // Finalizes the replacement.
577  for (auto [sel, semi] : semiRings)
578  rewriter.replaceOp(sel, semi->getResults());
579 
580  return success(!semiRings.empty());
581  }
582 
583 private:
584  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585  isRewritablePattern(GenericOp op, Operation *v) {
586  auto sel = dyn_cast<arith::SelectOp>(v);
587  if (!sel)
588  return std::nullopt;
589 
590  auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591  auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
592  // TODO: For simplicity, we only handle cases where both the true/false
593  // values are directly loaded from the input tensor. We can probably admit
594  // more cases in theory.
595  if (!tVal || !fVal)
596  return std::nullopt;
597 
598  // Helper lambda to determine whether the value is loaded from a dense input
599  // or is a loop invariant.
600  auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601  if (auto bArg = dyn_cast<BlockArgument>(v);
602  bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603  return true;
604  // If the value is defined outside the loop, it is a loop invariant.
605  return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606  };
607 
608  // If the condition value is loaded directly from a dense tensor or is a
609  // loop invariant, we can sparsify the kernel.
610  auto cond = sel.getCondition();
611  if (isValFromDenseInputOrInvariant(cond))
612  return std::make_tuple(cond, tVal, fVal);
613 
614  Value cmpL, cmpR;
615  if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616  matchers::m_Any(&cmpR))) ||
617  matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618  matchers::m_Any(&cmpR)))) {
619  // TODO: we can do it recursively to check whether all the leaf values are
620  // loaded from dense tensors or are loop invariants.
621  if (isValFromDenseInputOrInvariant(cmpL) ||
622  isValFromDenseInputOrInvariant(cmpR))
623  return std::make_tuple(cond, tVal, fVal);
624  }
625 
626  return std::nullopt;
627  };
628 };
629 
630 /// Rewrites a sparse reduction that would not sparsify directly since
631 /// doing so would only iterate over the stored elements, ignoring the
632 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
633 /// (note that reductions like add/sub/or/xor can directly be sparsified
634 /// since the implicit zeros do not contribute to the final result).
635 /// Note that prod/and are still included since, even though they often
636 /// are nullified in sparse data, they may still occur for special
637 /// situations in which e.g. some rows in a sparse matrix are fully
638 /// dense. For min/max, including the implicit zeros is a much more
639 /// common situation.
640 ///
641 /// TODO: this essentially "densifies" the operation; we want to implement
642 /// this much more efficiently by performing the reduction over the
643 /// stored values, and feed in the zero once if there were *any*
644 /// implicit zeros as well; but for now, at least we provide
645 /// the functionality
646 ///
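/// As an illustrative sketch (pseudo-IR) for a min-reduction body
///   %r = arith.minimumf %x, %v
/// the rewrite produces, roughly,
///   %u = sparse_tensor.unary %v   { present -> %v, absent -> 0.0 }
///   %r = sparse_tensor.reduce %u, %x, %identity { arith.minimumf }
/// where %identity is extracted from the initial (output) tensor.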
647 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
648 public:
649  using OpRewritePattern<GenericOp>::OpRewritePattern;
650 
651  LogicalResult matchAndRewrite(GenericOp op,
652  PatternRewriter &rewriter) const override {
653  // Reject non-reductions.
654  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655  op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656  return failure();
657  auto *inp = op.getDpsInputOperand(0);
658  auto *init = op.getDpsInitOperand(0);
659  if (!isSparseTensor(inp))
660  return failure();
661  // Look for direct x = x OP y for semi-ring ready reductions.
662  auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
663  .getOperand(0)
664  .getDefiningOp();
665  if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666  arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667  arith::MaxUIOp>(red))
668  return failure();
669  Value s0 = op.getBlock()->getArgument(0);
670  Value s1 = op.getBlock()->getArgument(1);
671  if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672  (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673  return failure();
674  // Identity.
675  Location loc = op.getLoc();
676  Value identity =
677  rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
678  // Unary {
679  // present -> value
680  // absent -> zero.
681  // }
682  Type rtp = s0.getType();
683  rewriter.setInsertionPointToStart(&op.getRegion().front());
684  auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
685  Block *present =
686  rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
687  rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
688  rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
689  rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
690  rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
691  auto zero =
692  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
693  rewriter.create<sparse_tensor::YieldOp>(loc, zero);
694  rewriter.setInsertionPointAfter(semiring);
695  // CustomReduce {
696  // x = x REDUC y, identity
697  // }
698  auto custom = rewriter.create<sparse_tensor::ReduceOp>(
699  loc, rtp, semiring.getResult(), s1, identity);
700  Block *region =
701  rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
702  rewriter.setInsertionPointToStart(&custom.getRegion().front());
703  IRMapping irMap;
704  irMap.map(red->getOperand(0), region->getArgument(0));
705  irMap.map(red->getOperand(1), region->getArgument(1));
706  auto *cloned = rewriter.clone(*red, irMap);
707  rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
708  rewriter.setInsertionPointAfter(custom);
709  rewriter.replaceOp(red, custom.getResult());
710  return success();
711  }
712 };
713 
714 /// Sparse rewriting rule for the print operator. This operation is mainly used
715 /// for debugging and testing. As such, it lowers to the vector.print operation,
716 /// which requires only very lightweight runtime support.
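/// For a CSR matrix, the emitted code prints output along the lines of the
/// following sketch (actual numbers depend on the tensor contents):
///
///   ---- Sparse Tensor ----
///   nse = 5
///   dim = ( 4, 4 )
///   lvl = ( 4, 4 )
///   pos[1] : ( 0, 2, 3, 4, 5 )
///   crd[1] : ( 0, 3, 1, 2, 3 )
///   values : ( 1, 2, 3, 4, 5 )
///   ----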
717 struct PrintRewriter : public OpRewritePattern<PrintOp> {
718 public:
719  using OpRewritePattern<PrintOp>::OpRewritePattern;
720  LogicalResult matchAndRewrite(PrintOp op,
721  PatternRewriter &rewriter) const override {
722  Location loc = op.getLoc();
723  auto tensor = op.getTensor();
724  auto stt = getSparseTensorType(tensor);
725  // Header with NSE.
726  auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
727  rewriter.create<vector::PrintOp>(
728  loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
729  rewriter.create<vector::PrintOp>(loc, nse);
730  // Print run-time contents for dim/lvl sizes.
731  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
732  printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
733  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
734  printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
735  // Use the "codegen" foreach loop construct to iterate over
736  // all typical sparse tensor components for printing.
737  foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
738  &stt](Type, FieldIndex,
739  SparseTensorFieldKind kind,
740  Level l, LevelType) {
741  switch (kind) {
742  case SparseTensorFieldKind::StorageSpec: {
743  break;
744  }
745  case SparseTensorFieldKind::PosMemRef: {
746  auto lvl = constantIndex(rewriter, loc, l);
747  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
748  rewriter.create<vector::PrintOp>(
749  loc, lvl, vector::PrintPunctuation::NoPunctuation);
750  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
751  auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
752  printContents(rewriter, loc, pos);
753  break;
754  }
755  case SparseTensorFieldKind::CrdMemRef: {
756  auto lvl = constantIndex(rewriter, loc, l);
757  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
758  rewriter.create<vector::PrintOp>(
759  loc, lvl, vector::PrintPunctuation::NoPunctuation);
760  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
761  Value crd = nullptr;
762  // For COO AoS storage, we want to print a single, linear view of
763  // the full coordinate storage at this level. For any other storage,
764  // we show the coordinate storage for every individual level.
765  if (stt.getAoSCOOStart() == l)
766  crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
767  else
768  crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
769  printContents(rewriter, loc, crd);
770  break;
771  }
772  case SparseTensorFieldKind::ValMemRef: {
773  rewriter.create<vector::PrintOp>(loc,
774  rewriter.getStringAttr("values : "));
775  auto val = rewriter.create<ToValuesOp>(loc, tensor);
776  printContents(rewriter, loc, val);
777  break;
778  }
779  }
780  return true;
781  });
782  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
783  rewriter.eraseOp(op);
784  return success();
785  }
786 
787 private:
788  // Helper to print contents of a single memref. For "push_back" vectors,
789  // we assume that the previous getters for pos/crd/val have added a
790  // slice-to-size view to make sure we just print the size and not the
791  // full capacity.
792  //
793  // Generates code to print (1-dim or higher):
794  // ( a0, a1, ... )
795  static void printContents(PatternRewriter &rewriter, Location loc,
796  Value vec) {
797  auto shape = cast<ShapedType>(vec.getType()).getShape();
798  SmallVector<Value> idxs;
799  printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
801  }
802 
803  // Helper to the helper.
804  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
805  Value vec, unsigned i, ArrayRef<int64_t> shape,
806  SmallVectorImpl<Value> &idxs) {
807  // Open bracket.
808  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
809  // Generate for loop.
810  auto zero = constantIndex(rewriter, loc, 0);
811  auto index = constantIndex(rewriter, loc, i);
812  auto size = rewriter.create<memref::DimOp>(loc, vec, index);
813  auto step = constantIndex(rewriter, loc, 1);
814  auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
815  idxs.push_back(forOp.getInductionVar());
816  rewriter.setInsertionPointToStart(forOp.getBody());
817  if (i < shape.size() - 1) {
818  // Enter deeper loop nest.
819  printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
820  } else {
821  // Actual contents printing.
822  auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
823  if (llvm::isa<ComplexType>(val.getType())) {
824  // Since the vector dialect does not support complex types in any op,
825  // we split those into (real, imag) pairs here.
826  Value real = rewriter.create<complex::ReOp>(loc, val);
827  Value imag = rewriter.create<complex::ImOp>(loc, val);
828  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
829  rewriter.create<vector::PrintOp>(loc, real,
830  vector::PrintPunctuation::Comma);
831  rewriter.create<vector::PrintOp>(loc, imag,
832  vector::PrintPunctuation::Close);
833  } else {
834  rewriter.create<vector::PrintOp>(
835  loc, val, vector::PrintPunctuation::NoPunctuation);
836  }
837  // Terminating comma (except at end).
838  auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
839  Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
840  bound, size);
841  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
842  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
843  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
844  }
845  idxs.pop_back();
846  rewriter.setInsertionPointAfter(forOp);
847  // Close bracket.
848  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
849  }
850 
851  // Helper method to print run-time lvl/dim sizes.
852  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
853  unsigned size, bool isDim) {
854  // Open bracket.
855  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
856  // Print unrolled contents (dimop requires constant value).
857  for (unsigned i = 0; i < size; i++) {
858  auto idx = constantIndex(rewriter, loc, i);
859  Value val;
860  if (isDim)
861  val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
862  else
863  val = rewriter.create<LvlOp>(loc, tensor, idx);
864  rewriter.create<vector::PrintOp>(
865  loc, val,
866  i != size - 1 ? vector::PrintPunctuation::Comma
867  : vector::PrintPunctuation::NoPunctuation);
868  }
869  // Close bracket and end of line.
870  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
871  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
872  }
873 };
874 
875 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
876 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
877 public:
878  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
879 
880  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
881  PatternRewriter &rewriter) const override {
882  Location loc = op.getLoc();
883  Value srcTensor = op.getSource();
884  const auto srcTp = getSparseTensorType(srcTensor);
885  const auto dstTp = getSparseTensorType(op.getResult());
886 
887  if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
888  !dstTp.hasStaticDimShape())
889  return failure();
890 
891  SmallVector<Value> srcSizes;
892  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
893  SmallVector<Value> dstSizes;
894  for (Dimension d : dstTp.getDimShape())
895  dstSizes.push_back(constantIndex(rewriter, loc, d));
896 
897  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
898  // Only need an unordered COO buffer if input and output are not sorted
899  // in the same way.
900  Type bufferTp = getBufferType(
901  dstTp.withoutDimToLvl(),
902  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
903  SmallVector<Value> dynSizes;
904  Value buffer = rewriter
905  .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
906  nnz, Attribute())
907  .getResult();
908 
909  // Convert src coordinates to dst coordinates by first collapsing them to 1D
910  // and then expanding them to match the rank of the destination tensor.
911  // Implemented as follows:
912  // foreach srcCoords %srcTensor
913  // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
914  // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
915  // insert expandedCoords, %buffer
916  //
917  // followed by an optional
918  // %t = sparse_tensor.cast %tmp
919  // depending on whether the input/output are sorted in the same way.
920  const auto encSrc = srcTp.getEncoding();
921  ForeachOp foreachOp = rewriter.create<ForeachOp>(
922  loc, srcTensor, buffer,
923  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
924  ValueRange reduc) {
925  const Dimension srcRank = srcTp.getDimRank();
926  SmallVector<Value> srcDcvs;
927  srcDcvs.reserve(srcRank);
928  for (Dimension d = 0; d < srcRank; d++) {
929  Level lvl = toLvl(encSrc, d);
930  srcDcvs.push_back(srcLcvs[lvl]);
931  }
932 
933  Value collapseSize = constantIndex(builder, loc, 1);
934  for (Dimension d = 0; d < srcRank; d++)
935  collapseSize =
936  builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
937  SmallVector<Value, 1> collapsedSizes = {collapseSize};
938 
939  ReassociationIndices collapseIdx;
940  for (Dimension i = 0; i < srcRank; i++)
941  collapseIdx.push_back(i);
942  SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
943  SmallVector<Value, 1> collapsedDcvs;
944  reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
945  collapsedSizes, collapsedDcvs);
946 
947  ReassociationIndices expandIdx;
948  for (Dimension i = 0; i < dstTp.getDimRank(); i++)
949  expandIdx.push_back(i);
950  SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
951  SmallVector<Value> dstDcvs;
952  reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
953  dstSizes, dstDcvs);
954 
955  auto t =
956  builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
957  builder.create<sparse_tensor::YieldOp>(loc, t);
958  });
959 
960  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
961  if (bufferTp != dstTp) {
962  auto dstRTT = dstTp.getRankedTensorType();
963  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
964  rewriter.create<DeallocTensorOp>(loc, t);
965  t = converted;
966  }
967  rewriter.replaceOp(op, t);
968  return success();
969  }
970 };
971 
972 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
973 template <typename ReshapeOp>
974 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
975 public:
976  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
977 
978  LogicalResult matchAndRewrite(ReshapeOp op,
979  PatternRewriter &rewriter) const override {
980  Location loc = op.getLoc();
981  Value srcTensor = op.getSrc();
982  const auto srcTp = getSparseTensorType(srcTensor);
983  const auto dstTp = getSparseTensorType(op.getResult());
984  if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
985  return failure();
986 
987  // Generate code to represent the static dimension constants or compute
988  // the dynamic dimension values.
989  SmallVector<Value> srcSizes;
990  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
991  SmallVector<Value> dstSizes;
992  SmallVector<Value> dstDynSizes;
993  if (dstTp.hasStaticDimShape()) {
994  for (Dimension d : dstTp.getDimShape())
995  dstSizes.push_back(constantIndex(rewriter, loc, d));
996  } else {
997  ArrayRef<Size> dstShape = dstTp.getDimShape();
998  genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
999  op.getReassociationIndices());
1000  for (auto [idx, shape] : llvm::enumerate(dstShape)) {
1001  if (shape == ShapedType::kDynamic)
1002  dstDynSizes.push_back(dstSizes[idx]);
1003  }
1004  }
1005  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
1006  // Only need an unordered COO buffer if input and output are not sorted
1007  // in the same way.
1008  Type bufferTp = getBufferType(
1009  dstTp.withoutDimToLvl(),
1010  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1011 
1012  Value buffer =
1013  rewriter
1014  .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
1015  /*sizeHint=*/nnz, Attribute())
1016  .getResult();
1017 
1018  // Implement the sparse2sparse reshape as follows:
1019  // foreach srcCoords %srcTensor
1020  // insert reshapeCvs(srcCoords), %buffer
1021  //
1022  // followed by an optional
1023  // %t = sparse_tensor.cast %tmp
1024  // depending on whether the input/output are sorted in the same way.
1025  const auto encSrc = srcTp.getEncoding();
1026  ForeachOp foreachOp = rewriter.create<ForeachOp>(
1027  loc, srcTensor, buffer,
1028  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1029  ValueRange reduc) {
1030  const Dimension dimRank = srcTp.getDimRank();
1031  SmallVector<Value> srcDcvs;
1032  srcDcvs.reserve(dimRank);
1033  for (Dimension d = 0; d < dimRank; d++) {
1034  Level lvl = toLvl(encSrc, d);
1035  srcDcvs.push_back(srcLcvs[lvl]);
1036  }
1037  SmallVector<Value> dstDcvs;
1038  reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1039  srcDcvs, dstSizes, dstDcvs);
1040  auto t =
1041  builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1042  builder.create<sparse_tensor::YieldOp>(loc, t);
1043  });
1044 
1045  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
1046  if (bufferTp != dstTp) {
1047  auto dstRTT = dstTp.getRankedTensorType();
1048  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
1049  rewriter.create<DeallocTensorOp>(loc, t);
1050  t = converted;
1051  }
1052  rewriter.replaceOp(op, t);
1053  return success();
1054  }
1055 };
1056 
1057 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1058 /// operator.
1059 template <typename ReshapeOp>
1060 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1061 public:
1062  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1063 
1064  LogicalResult matchAndRewrite(ReshapeOp op,
1065  PatternRewriter &rewriter) const override {
1066  Location loc = op->getLoc();
1067  auto encDst = getSparseTensorEncoding(op.getResult().getType());
1068  auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1069  // Since a pure dense expansion is very cheap (change of view), for
1070  // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1071  // conversion from the reshape operation itself.
1072  // All other cases are handled elsewhere.
1073  if (encDst && encSrc) {
1074  return failure();
1075  }
1076  if (encSrc) {
1077  auto rtp = getRankedTensorType(op.getSrc());
1078  auto denseTp =
1079  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1080  auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
1081  rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1082  return success();
1083  }
1084  if (encDst) {
1085  auto rtp = getRankedTensorType(op.getResult());
1086  auto denseTp =
1087  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1088  ReshapeOp reshape;
1089  if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1090  reshape = rewriter.create<ReshapeOp>(
1091  loc, denseTp, op.getSrc(), op.getReassociation(),
1092  op.getOutputShape(), op.getStaticOutputShape());
1093  } else {
1094  reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
1095  op.getReassociation());
1096  }
1097  Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
1098  rewriter.replaceOp(op, convert);
1099  return success();
1100  }
1101  return failure();
1102  }
1103 };
1104 
1105 // A trivial wrapper to help generate different operations for dense/sparse
1106 // tensors.
1107 struct TensorLike {
1108  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1109  ValueRange sizes) {
1110  SmallVector<Value> dynSzs;
1111  getDynamicSizes(rtt, sizes, dynSzs);
1112 
1113  val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
1114  if (!isSparse()) {
1115  Value c0 = constantZero(builder, loc, rtt.getElementType());
1116  val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
1117  }
1118  }
1119 
1120  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1121  val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1122  }
1123 
1124  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1125  if (isSparse())
1126  return builder.create<LoadOp>(loc, val, true);
1127  return val;
1128  }
1129 
1130  bool isSparse() const {
1131  return getSparseTensorEncoding(val.getType()) != nullptr;
1132  }
1133 
1134  Value val;
1135 };
1136 
1137 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1138  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1139  LogicalResult matchAndRewrite(tensor::DimOp op,
1140  PatternRewriter &rewriter) const override {
1141  std::optional<int64_t> dim = op.getConstantIndex();
1142  auto stt = getSparseTensorType(op.getSource());
1143  if (!dim || !stt.hasEncoding())
1144  return failure();
1145 
1146  if (stt.isPermutation()) {
1147  rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1148  toLvl(stt.getEncoding(), *dim));
1149  return success();
1150  }
1151 
1152  // Non-permutation dim2lvl/lvl2dim maps.
1153  // Compute as follows:
1154  // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1155  // Note that this is not the most efficient way (but a more general one) to
1156  // do the lvl-to-dim translation; e.g., for BSR, the dimension size can be
1157  // computed simply as lvl_size * block_size.
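  // As a hypothetical worked example: for a BSR-like map where
  //   d0 = l0 * 2 + l2   (block size 2, lvl_size(l0) = n/2, lvl_size(l2) = 2)
  // the maximal dimension coordinate is (n/2 - 1) * 2 + (2 - 1) = n - 1, so
  // adding one recovers the dimension size n.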
1158  Location loc = op.getLoc();
1159  SmallVector<Value> maxLvlCrds;
1160  for (Level l = 0; l < stt.getLvlRank(); l++) {
1161  Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1162  Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1163  loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
1164  maxLvlCrds.push_back(maxLvlCrd);
1165  }
1166 
1167  AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
1168  Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1169  op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
1170  maxLvlCrds);
1171 
1172  Value dimSz = rewriter.create<arith::AddIOp>(
1173  loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1174  rewriter.replaceOp(op, dimSz);
1175  return success();
1176  }
1177 };
1178 
1179 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1180  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
1181  LogicalResult matchAndRewrite(ConcatenateOp op,
1182  PatternRewriter &rewriter) const override {
1183  if (op.needsExtraSort())
1184  op.emitError("ConcatenateOp not staged");
1185 
1186  const Location loc = op.getLoc();
1187  const auto dstTp = getSparseTensorType(op);
1188  const Dimension conDim = op.getDimension();
1189  SmallVector<Value> sizes;
1190  concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1191 
1192  // %t = concatenate %s1, %s2, %s3 {dim = 1}
1193  // ==>
1194  // if (isSparseDst)
1195  // if (allDense)
1196  // %tmp = bufferization.alloc_tensor dstTp
1197  // else
1198  // %tmp = bufferization.alloc_tensor : unordered COO
1199  // else
1200  // %tmp = memref.alloc : dense tensor
1201  // foreach in %s1 : insert d0, d1, %tmp
1202  // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1203  // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1204 
1205  TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1206  Value offset = constantIndex(rewriter, loc, 0);
1207  Value iterArg = dstBuf.val;
1208 
1209  ForeachOp foreachOp;
1210  for (Value input : op.getInputs()) {
1211  // Builds a for op for each input tensor to append new values into the
1212  // output tensor.
1213  foreachOp = rewriter.create<ForeachOp>(
1214  loc, input, iterArg,
1215  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1216  ValueRange reduc) {
1217  SmallVector<Value> offDimCrd(dcvs);
1218  offDimCrd[conDim] =
1219  builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1220 
1221  // Enters foreach, updates the SSA chain.
1222  dstBuf.val = reduc.front();
1223  if (!dstTp.isAllDense()) {
1224  Value cond = genIsNonzero(builder, loc, v);
1225  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1226  /*else*/ true);
1227  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1228  builder.create<scf::YieldOp>(loc, dstBuf.val);
1229 
1230  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1231  dstBuf.insert(builder, loc, v, offDimCrd);
1232  builder.create<scf::YieldOp>(loc, dstBuf.val);
1233 
1234  // Exits the ifOp, update the sparse tensor SSA value.
1235  builder.setInsertionPointAfter(ifOp);
1236  dstBuf.val = ifOp.getResult(0);
1237  } else {
1238  dstBuf.insert(builder, loc, v, offDimCrd);
1239  }
1240  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1241  });
1242  // Accumulates the offset. Note that only static-shaped inputs are allowed
1243  // by concatenate op verifier, which saves us from computing the offset
1244  // dynamically.
1245  const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1246  assert(!ShapedType::isDynamic(sz));
1247  offset = rewriter.create<arith::AddIOp>(loc, offset,
1248  constantIndex(rewriter, loc, sz));
1249  iterArg = foreachOp.getResult(0);
1250  dstBuf.val = iterArg;
1251  }
1252 
1253  dstBuf.val = iterArg;
1254  Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1255  rewriter.replaceOp(op, ret);
1256  return success();
1257  }
1258 };
1259 
1260 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1261  using OpRewritePattern<ConvertOp>::OpRewritePattern;
1262  LogicalResult matchAndRewrite(ConvertOp op,
1263  PatternRewriter &rewriter) const override {
1264  if (op.needsExtraSort())
1265  return op.emitError("ConvertOp not staged.");
1266 
1267  // TODO: Maybe we want a different operation for this too.
1268  auto encDst = getSparseTensorEncoding(op.getType());
1269  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1270  if (encDst && encSrc && !encSrc.isSlice() &&
1271  encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1272  // Trivial tensor conversion and simple element type conversion are handled
1273  // in codegen.
1274  return failure();
1275  }
1276 
1277  Location loc = op.getLoc();
1278  Value src = op.getSource();
1279 
1280  SparseTensorType srcStt = getSparseTensorType(op.getSource());
1281  SparseTensorType dstStt = getSparseTensorType(op.getDest());
1282 
1283  bool fromSparseConst = false;
1284  if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1285  if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1286  fromSparseConst = true;
1287 
1288  const AffineMapAttr foreachOrder =
1289  (!dstStt.isIdentity() && fromSparseConst)
1290  ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1291  : nullptr;
1292 
1293  bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1294 
1295  SmallVector<Value> sizes;
1296  sizesFromSrc(rewriter, sizes, loc, src);
1297  ValueRange vs;
1298  TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1299 
1300  auto foreachOp = rewriter.create<ForeachOp>(
1301  loc, src, dstBuf.val, foreachOrder,
1302  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1303  ValueRange reduc) {
1304  // Enters the loop, update the SSA value for insertion chain.
1305  dstBuf.val = reduc.front();
1306  if (!skipZeroCheck) {
1307  Value cond = genIsNonzero(builder, loc, v);
1308  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1309  /*else*/ true);
1310  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1311  builder.create<scf::YieldOp>(loc, dstBuf.val);
1312 
1313  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1314  dstBuf.insert(builder, loc, v, dcvs);
1315  builder.create<scf::YieldOp>(loc, dstBuf.val);
1316 
1317  // Exits the ifOp, update the sparse tensor SSA value.
1318  builder.setInsertionPointAfter(ifOp);
1319  dstBuf.val = ifOp.getResult(0);
1320  } else {
1321  dstBuf.insert(builder, loc, v, dcvs);
1322  }
1323  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1324  });
1325 
1326  rewriter.setInsertionPointAfter(foreachOp);
1327 
1328  // Exits the for loop, links the SSA chain.
1329  dstBuf.val = foreachOp.getResult(0);
1330 
1331  Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1332  rewriter.replaceOp(op, ret);
1333  return success();
1334  }
1335 };
1336 
1337 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1338  using OpRewritePattern<CrdTranslateOp>::OpRewritePattern;
1339  LogicalResult matchAndRewrite(CrdTranslateOp op,
1340  PatternRewriter &rewriter) const override {
1341  AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1342  ? op.getEncoder().getDimToLvl()
1343  : op.getEncoder().getLvlToDim();
1344 
1345  SmallVector<Value> outCrds;
1346  for (AffineExpr result : map.getResults()) {
1347  // TODO: we should probably expand the affine map to IR using our own
1348  // rules, since affine.apply assumes signed values, while the coordinates
1349  // we provide must always be signless.
1350  Value trans = rewriter.create<affine::AffineApplyOp>(
1351  op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1352  op.getInCrds());
1353  outCrds.push_back(trans);
1354  }
1355  rewriter.replaceOp(op, outCrds);
1356  return success();
1357  }
1358 };
1359 
1360 /// Sparse rewriting rule for the foreach operator.
1361 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1362 public:
1363  using OpRewritePattern<ForeachOp>::OpRewritePattern;
1364 
1365  LogicalResult matchAndRewrite(ForeachOp op,
1366  PatternRewriter &rewriter) const override {
1367 
1368  auto loc = op.getLoc();
1369  Value input = op.getTensor();
1370  SmallVector<Value> reduc = op.getInitArgs();
1371  const auto stt = getSparseTensorType(input);
1372  const Level lvlRank = stt.getLvlRank();
1373 
1374  // Special case: foreach over a sparse constant uses its own rewriting
1375  // rule.
1376  if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1377  if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1378  return genForeachOnSparseConstant(op, rewriter, attr);
1379  }
1380  }
1381 
1382  // Otherwise, use loop emitter to generate loops.
1383  const auto enc = stt.getEncoding();
1384 
1385  // 1. Generates loop for the sparse input.
1386  LoopEmitter loopEmitter(
1387  ValueRange{input},
1388  StringAttr::get(getContext(), ForeachOp::getOperationName()));
1389  loopEmitter.initializeLoopEmit(rewriter, loc);
1390  for (Level l = 0; l < lvlRank; l++) {
1391  // TODO: provide a utility function for loop sequences that only contain
1392  // one for loop?
1393  const SmallVector<TensorLevel, 1> tidLvls{
1394  loopEmitter.makeTensorLevel(0, l)};
1395  loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1396  // Note that reduc will be taken care of by the loop emitter and gets
1397  // updated in place.
1398  loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1399  reduc);
1400  }
1401 
1402  SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1403  if (op.getOrder()) {
1404  // TODO: Support it so that we can do direct conversion from CSR->BSR.
1405  llvm_unreachable(
1406  "Level order not yet implemented on non-constant input tensors.");
1407  }
1408 
1409  Value vals = loopEmitter.getValBuffer()[0];
1410  SmallVector<Value> pos = loopEmitter.getValPosits(0);
1411  // Loads the value from sparse tensor using position-index;
1412  // loads the value from dense tensor using coords.
1413  Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1414  : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1415 
1416  // 2. Inline the block in the foreach operator.
1417  Block *srcBlock = op.getBody();
1418 
1419  // Remap coordinates.
1420  SmallVector<Value> args =
1421  enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1422 
1423  // Remap value.
1424  args.push_back(val);
1425  // Remap reduction variables.
1426  args.append(reduc);
1427 
1428  // Remove sparse_tensor.yield.
1429  SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1430  rewriter.eraseOp(srcBlock->getTerminator());
1431 
1432  Operation &last = rewriter.getBlock()->back();
1433  if (llvm::isa<scf::YieldOp>(last)) {
1434  // Because `scf.for` inserts an implicit yield op upon creation when there
1435  // is no reduction variable, we reset the insertion point such that the
1436  // block is inlined right before the yield op.
1437  rewriter.setInsertionPoint(&last);
1438  }
1439 
1440  rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1441  rewriter.getInsertionPoint(), args);
1442  rewriter.setInsertionPointToEnd(rewriter.getBlock());
1443  for (Level l = 0; l < lvlRank; l++) {
1444  // Link the reduction chain. Note that the loop emitter updates reducValue
1445  // in place.
1446  loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1447  loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1448  }
1449 
1450  // Replace the foreach operator with the value returned by the outermost
1451  // for loop.
1452  rewriter.replaceOp(op, reducValue);
1453  return success();
1454  }
1455 };
1456 
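// Illustrative sketch of the loops generated above for a hypothetical #CSR
// input (buffer names are placeholders): a sparse_tensor.foreach over
// tensor<?x?xf64, #CSR> becomes, roughly,
//   scf.for %i = %c0 to %dimSize step %c1 {      // dense level 0
//     scf.for %p = %lo to %hi step %c1 {         // compressed level 1
//       %j = memref.load %crdBuffer[%p]
//       %v = memref.load %valBuffer[%p]
//       <inlined foreach body>(%i, %j, %v)
//     }
//   }
// with any initial arguments threaded through as iter_args of the scf.for ops.
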
1457 /// Sparse rewriting rule for the new operator.
1458 struct NewRewriter : public OpRewritePattern<NewOp> {
1459  using OpRewritePattern::OpRewritePattern;
1460  LogicalResult matchAndRewrite(NewOp op,
1461  PatternRewriter &rewriter) const override {
1462  Location loc = op.getLoc();
1463  auto stt = getSparseTensorType(op.getResult());
1464  if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1465  return failure();
1466 
1467  // Implement the NewOp as follows:
1468  //   %orderedCoo = sparse_tensor.new %filename
1469  //   %t = sparse_tensor.convert %orderedCoo
1470  // with enveloping sparse_tensor.reinterpret_map ops for non-permutations.
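  // For example (a sketch, with a hypothetical non-permutation encoding and
  // its ordered COO counterpart #OrderedCOO), the expansion looks roughly like:
  //   %coo = sparse_tensor.new %src : !llvm.ptr to tensor<?x?xf64, #OrderedCOO>
  //   %demapped = sparse_tensor.reinterpret_map %coo ...   // drop dimToLvl
  //   %conv = sparse_tensor.convert %demapped ...
  //   %t = sparse_tensor.reinterpret_map %conv ...         // restore dimToLvl
  // followed by a dealloc_tensor of %coo (inserted below).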
1471  RankedTensorType dstTp = stt.getRankedTensorType();
1472  RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1473  Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1474  Value convert = cooTensor;
1475  auto enc = stt.getEncoding();
1476  if (!stt.isPermutation()) { // demap coo, demap dstTp
1477  auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1478  convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1479  dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1480  }
1481  convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1482  if (!stt.isPermutation()) // remap to original enc
1483  convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1484  rewriter.replaceOp(op, convert);
1485 
1486  // Release the temporary ordered COO tensor.
1487  rewriter.setInsertionPointAfterValue(convert);
1488  rewriter.create<DeallocTensorOp>(loc, cooTensor);
1489 
1490  return success();
1491  }
1492 };
1493 
1494 /// Sparse rewriting rule for the out operator.
1495 struct OutRewriter : public OpRewritePattern<OutOp> {
1496  using OpRewritePattern::OpRewritePattern;
1497  LogicalResult matchAndRewrite(OutOp op,
1498  PatternRewriter &rewriter) const override {
1499  Location loc = op.getLoc();
1500  // Calculate NNZ.
1501  Value src = op.getTensor();
1502  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1503 
1504  // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1505  const auto srcTp = getSparseTensorType(src);
1506  const Dimension dimRank = srcTp.getDimRank();
1507  Type indexTp = rewriter.getIndexType();
1508  Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1509 
1510  // Generate code to calculate dimension size values and store the values to
1511  // the buffer.
1512  SmallVector<Value> dims;
1513  sizesForTensor(rewriter, dims, loc, srcTp, src);
1514  for (Dimension d = 0; d < dimRank; d++) {
1515  rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1516  constantIndex(rewriter, loc, d));
1517  }
1518 
1519  // Create a sparse tensor writer and output metadata.
1520  Type opaqueTp = getOpaquePointerType(rewriter);
1521  Value writer =
1522  createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1523  {op.getDest()}, EmitCInterface::Off)
1524  .getResult(0);
1525  Value rankValue = constantIndex(rewriter, loc, dimRank);
1526  createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1527  {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1528 
1529  Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1530  Type eltTp = srcTp.getElementType();
1531  SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1532  primaryTypeFunctionSuffix(eltTp)};
1533  Value value = genAllocaScalar(rewriter, loc, eltTp);
1534  ModuleOp module = op->getParentOfType<ModuleOp>();
1535 
1536  // For each element in the source tensor, output the element.
1537  rewriter.create<ForeachOp>(
1538  loc, src, std::nullopt,
1539  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1540  ValueRange reduc) {
1541  for (Dimension d = 0; d < dimRank; d++) {
1542  rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1543  constantIndex(builder, loc, d));
1544  }
1545  rewriter.create<memref::StoreOp>(loc, v, value);
1546  SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1547  FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1548  EmitCInterface::On);
1549  builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1550  builder.create<sparse_tensor::YieldOp>(loc);
1551  });
1552 
1553  // Release the writer.
1554  createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1555  EmitCInterface::Off);
1556 
1557  rewriter.eraseOp(op);
1558  return success();
1559  }
1560 };
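
// Illustrative sketch of the code emitted by OutRewriter for a 2-d f64 tensor
// (value names and the exact function suffix are assumptions):
//   %nnz    = sparse_tensor.number_of_entries %src
//   %writer = call @createSparseTensorWriter(%dest)
//   call @outSparseTensorWriterMetaData(%writer, %rank, %nnz, %dimSizes)
//   sparse_tensor.foreach in %src ... {      // one call per stored entry
//     call @outSparseTensorWriterNextF64(%writer, %rank, %dimCoords, %value)
//   }
//   call @delSparseTensorWriter(%writer)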
1561 
1562 } // namespace
1563 
1564 //===---------------------------------------------------------------------===//
1565 // Methods that add patterns described in this file to a pattern list.
1566 //===---------------------------------------------------------------------===//
1567 
1568 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1569  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1570  FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1571  GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1572  patterns.getContext());
1573 }
1574 
1575 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1576                                                    bool enableRT,
1577                                                    bool enableConvert) {
1578  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1579  ReshapeRewriter<tensor::CollapseShapeOp>,
1580  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1581  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1582  SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1583  patterns.getContext());
1584 
1585  if (enableConvert)
1586  patterns.add<DirectConvertRewriter>(patterns.getContext());
1587  if (!enableRT)
1588  patterns.add<NewRewriter>(patterns.getContext());
1589 }
1590 
1591 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1592  // Run CrdTranslateRewriter later in the pipeline so that the operation can
1593  // be folded before lowering to affine.apply.
1594  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1595 }
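
// A minimal usage sketch (hypothetical pass body; `op`, `ctx`, and the
// surrounding pass are assumptions, not part of this file):
//   RewritePatternSet patterns(ctx);
//   populatePreSparsificationRewriting(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();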