SparseTensorRewriting.cpp
1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenUtils.h"
14 #include "Utils/LoopEmitter.h"
15 
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/Support/LLVM.h"
32 
33 using namespace mlir;
34 using namespace mlir::bufferization;
35 using namespace mlir::linalg;
36 using namespace mlir::sparse_tensor;
37 
38 //===---------------------------------------------------------------------===//
39 // Helper methods for the actual rewriting rules.
40 //===---------------------------------------------------------------------===//
41 
42 // Helper method to match any typed zero.
43 static bool isZeroValue(Value val) {
44  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45 }
46 
47 // Helper to detect a sparse tensor type operand.
48 static bool isSparseTensor(Value v) {
49  auto enc = getSparseTensorEncoding(v.getType());
50  return enc && !llvm::all_of(enc.getLvlTypes(),
51  [](auto lt) { return lt == LevelFormat::Dense; });
52 }
53 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54 
55 // Helper method to find zero/uninitialized tensor materialization.
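// For illustration (editorial sketch, not part of the original file): with
// isZero=false this accepts materializations such as
//   %t = bufferization.alloc_tensor(...)   (no copy operand)
//   %t = tensor.empty(...)
// while with isZero=true it accepts an alloc_tensor whose copy operand is a
// typed zero, or any value that is itself a typed zero.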
56 static bool isMaterializing(OpOperand *op, bool isZero) {
57  Value val = op->get();
58  // Check allocation, with zero alloc when required.
59  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60  Value copy = alloc.getCopy();
61  if (isZero)
62  return copy && isZeroValue(copy);
63  return !copy;
64  }
65  // Check for empty tensor materialization.
66  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67  return !isZero;
68  // Last resort for zero alloc: the whole value is zero.
69  return isZero && isZeroValue(val);
70 }
71 
72 // Helper to detect sampling operation.
73 static bool isSampling(GenericOp op) {
74  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77  // Both scalar input arguments used exactly once.
78  Value s1 = op.getBlock()->getArgument(0);
79  Value s2 = op.getBlock()->getArgument(1);
80  return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81  (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82  }
83  }
84  return false;
85 }
86 
87 // Helper to detect chain of multiplications that do not involve x.
88 static bool isMulChain(Value val, Value x) {
89  if (auto arg = dyn_cast<BlockArgument>(val))
90  return arg != x;
91  if (auto *def = val.getDefiningOp()) {
92  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93  return isMulChain(def->getOperand(0), x) &&
94  isMulChain(def->getOperand(1), x);
95  }
96  return false;
97 }
98 
99 // Helper to detect x = x + <multiplications>.
100 static bool isSumOfMul(GenericOp op) {
101  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104  Value x = op.getBlock()->getArguments().back();
105  return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106  (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107  }
108  }
109  return false;
110 }
111 
112 // Helper to detect direct yield of a zero value.
113 static bool isZeroYield(GenericOp op) {
114  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116  if (arg.getOwner()->getParentOp() == op) {
117  return isZeroValue(op->getOperand(arg.getArgNumber()));
118  }
119  }
120  return isZeroValue(yieldOp.getOperand(0));
121 }
122 
123 /// Populates given sizes array from type (for static sizes) and from
124 /// the tensor (for dynamic sizes).
125 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126  Location loc, ShapedType stp, Value tensor) {
127  for (const auto &d : enumerate(stp.getShape())) {
128  Value dim;
129  if (d.value() == ShapedType::kDynamic)
130  dim = tensor::DimOp::create(builder, loc, tensor, d.index());
131  else
132  dim = constantIndex(builder, loc, d.value());
133  sizes.push_back(dim);
134  }
135 }
136 
137 static RankedTensorType getBufferType(const SparseTensorType &stt,
138  bool needTmpCOO) {
139  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140  : stt.getRankedTensorType();
141 }
142 
143 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145 /// sizes to dynSizes.
146 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147  SmallVectorImpl<Value> &dynSizes) {
148  for (const auto &d : enumerate(tp.getShape())) {
149  if (d.value() == ShapedType::kDynamic)
150  dynSizes.push_back(sizes[d.index()]);
151  }
152 }
153 
154 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155  RewriterBase &rewriter,
156  SparseElementsAttr attr) {
157  auto loc = op.getLoc();
158  SmallVector<Value> reduc = op.getInitArgs();
159 
160  // Foreach on constant.
161  foreachInSparseConstant(
162      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163  [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164  SmallVector<Value> args;
165  args.append(cvs.begin(), cvs.end());
166  args.push_back(v);
167  args.append(reduc);
168  // Clones the foreach op to get a copy of the loop body.
169  auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170  assert(args.size() == cloned.getBody()->getNumArguments());
171  Operation *yield = cloned.getBody()->getTerminator();
172  rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173  // clean up
174  rewriter.eraseOp(cloned);
175  reduc = yield->getOperands();
176  rewriter.eraseOp(yield);
177  });
178 
179  rewriter.replaceOp(op, reduc);
180  return success();
181 }
182 
183 /// Populates the given sizes array for concatenation from types (for static
184 /// sizes) and from the source tensors (for dynamic sizes).
185 static void concatSizesFromInputs(OpBuilder &builder,
186  SmallVectorImpl<Value> &sizes, Location loc,
187  ShapedType dstTp, ValueRange srcs,
188  unsigned dim) {
189  auto dstShape = dstTp.getShape();
190  sizesFromSrc(builder, sizes, loc, srcs[0]);
191 
192  // Sum up on the `dim` if the dimension is dynamic.
193  if (dstShape[dim] != ShapedType::kDynamic) {
194  // Faithfully take the static size.
195  sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196  } else {
197  // Else, compute the shape dynamically.
198  for (const auto &src : srcs.drop_front()) {
199  Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200  // Sum up all the sizes.
201  sizes[dim] = arith::AddIOp::create(builder, loc, sizes[dim], srcSz);
202  }
203  }
204 }
205 
206 //===---------------------------------------------------------------------===//
207 // The actual sparse tensor rewriting rules.
208 //===---------------------------------------------------------------------===//
209 
210 namespace {
211 
212 /// TODO: move it to tensor dialect instead.
213 ///
214 /// Fold `tensor.concat` and `tensor.extract_slice`
215 ///
216 /// %concat = tensor.concat dim(2) %t0, %t1
217 /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218 /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220 /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222 ///
223 /// Becomes
224 ///
225 /// %extract0, %extract1 = %t0, %t1
226 struct FuseExtractSliceWithConcat
227  : public OpRewritePattern<tensor::ExtractSliceOp> {
228  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231  PatternRewriter &rewriter) const override {
232  auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233  if (!concatOp)
234  return failure();
235 
236  Location loc = extractOp.getLoc();
237  int64_t dim = concatOp.getDim();
238  int64_t rank = extractOp.getResultType().getRank();
239 
240  SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241  SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242 
243  // Compute the partial sums for the slice offsets.
244  AffineExpr sum = rewriter.getAffineDimExpr(0);
245  SmallVector<AffineExpr> partialSums = {sum};
246  SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247  for (auto [idx, input] :
248  llvm::enumerate(concatOp.getInputs().drop_back())) {
249  sum = sum + rewriter.getAffineDimExpr(idx + 1);
250  partialSums.push_back(sum);
251  offsetStrides.push_back(
252  rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253  }
254  auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255  partialSums, rewriter.getContext());
256  SmallVector<OpFoldResult> dimOffsets =
257      affine::makeComposedFoldedMultiResultAffineApply(
258          rewriter, loc, partialSumMap, offsetStrides);
259 
260  auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261  for (auto [l, r] : llvm::zip(lhs, rhs)) {
262  std::optional<int64_t> staticVal = getConstantIntValue(l);
263  if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264  return false;
265  }
266  return lhs.size() == rhs.size();
267  };
268 
269  for (auto [i, input, offset] :
270  llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271  SmallVector<OpFoldResult> srcSizes =
272  tensor::getMixedSizes(rewriter, loc, input);
273  srcOffsets[dim] = offset;
274 
275  SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276  SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277  SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278 
279  if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280  allEqual(srcStrides, dstStrides)) {
281  Value operand = concatOp.getOperand(i);
282  if (operand.getType() == extractOp.getResultType())
283  rewriter.replaceOp(extractOp, operand);
284  break;
285  }
286  }
287 
288  return success();
289  }
290 };
291 
292 /// Rewriting rule that fuses sparse_tensor.convert into producer.
293 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294 public:
295  using OpRewritePattern<ConvertOp>::OpRewritePattern;
296 
297  LogicalResult matchAndRewrite(ConvertOp op,
298  PatternRewriter &rewriter) const override {
299  auto producer = op.getSource().getDefiningOp<GenericOp>();
300  if (!producer || producer.getDpsInits().size() != 1 ||
301  !isMaterializing(producer.getDpsInitOperand(0), false) ||
302  !producer.getResult(0).hasOneUse()) {
303  return failure();
304  }
305  // Clone the materialization operation, but update the result to sparse.
306  rewriter.setInsertionPoint(producer);
307  Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
308  Operation *cloned = rewriter.clone(*init);
309  cloned->getResult(0).setType(op.getResult().getType());
310 
311  rewriter.modifyOpInPlace(producer, [&]() {
312  producer.getDpsInitsMutable().assign(cloned->getResults());
313  producer.getResult(0).setType(op.getResult().getType());
314  });
315 
316  rewriter.replaceAllOpUsesWith(op, producer);
317  op->erase();
318 
319  return success();
320  }
321 };
322 
323 /// Rewriting rule that folds a direct yield of zero into the initial
324 /// allocation (or into a constant zero for static dense outputs).
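/// For example (editorial sketch, types and names are illustrative only):
///
///   %init = bufferization.alloc_tensor() : tensor<8x8xf64, #CSR>
///   %0 = linalg.generic ... outs(%init : tensor<8x8xf64, #CSR>) {
///     ...
///     linalg.yield %zero : f64
///   } -> tensor<8x8xf64, #CSR>
///
/// can be replaced by %init itself (an all-zero sparse tensor); for a dense,
/// statically shaped output the op is replaced by a constant zero tensor.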
324 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325 public:
326  using OpRewritePattern<GenericOp>::OpRewritePattern;
327 
328  LogicalResult matchAndRewrite(GenericOp op,
329  PatternRewriter &rewriter) const override {
330  if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331  !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
332  !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333  return failure();
334  auto outputType = getRankedTensorType(op.getResult(0));
335  // Yielding zero on newly materialized sparse tensor can be
336  // optimized directly (regardless of dynamic or static size).
337  if (getSparseTensorEncoding(outputType)) {
338  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339  return success();
340  }
341  // Use static zero value directly instead of materialization.
342  if (!outputType.hasStaticShape())
343  return failure();
344  Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
345  rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
346  rewriter.eraseOp(def);
347  return success();
348  }
349 };
350 
351 /// Rewriting rule that converts two kernels:
352 ///
353 /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
354 /// X(i,j) = S(i,j) * T(i,j)
355 ///
356 /// into a single kernel, using distributive law:
357 ///
358 /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
359 ///
360 /// This kind of fusion (merging two ops into one but using arithmetic
361 /// equalities that may not hold for floating-point computations) would
362 /// be undesirable in the dense case, since we distribute the multiplication
363 /// into the reduction loop. However, for sparse sampling tensor S, such
364 /// a fusion may actually reduce the asymptotic complexity of the kernel,
365 /// since intermediate results may be nullified.
366 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
367 public:
368  using OpRewritePattern<GenericOp>::OpRewritePattern;
369 
370  LogicalResult matchAndRewrite(GenericOp op,
371  PatternRewriter &rewriter) const override {
372  // Check consumer.
373  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374  op.getNumResults() != 1 ||
375  op.getNumParallelLoops() != op.getNumLoops() ||
376  !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377  !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378  !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379  return failure();
380  // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
381  // operand can be sparse or dense, since the point of this rewriting rule
382  // is detecting a situation in which *more* sparsity is introduced into
383  // a computation, be it already sparse or still dense.
384  unsigned other = 0;
385  if (isSparseTensor(op.getDpsInputOperand(0)))
386  other = 1;
387  else if (!isSparseTensor(op.getDpsInputOperand(1)))
388  return failure();
389  // Check producer.
390  auto prod = dyn_cast_or_null<GenericOp>(
391  op.getDpsInputOperand(other)->get().getDefiningOp());
392  if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393  !prod.getResult(0).hasOneUse())
394  return failure();
395  // Sampling consumer and sum of multiplication chain producer.
396  if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
397  !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398  !isSampling(op) || !isSumOfMul(prod))
399  return failure();
400  // Modify operand structure of producer and consumer.
401  Location loc = prod.getLoc();
402  SmallVector<Value> inputOps = prod.getInputs();
403  SmallVector<Value> outputOps = op.getOutputs();
404  SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
405  inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406  fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
407  // Fuse producer and consumer into a new generic op.
408  auto fusedOp = GenericOp::create(
409  rewriter, loc, op.getResult(0).getType(), inputOps, outputOps,
410  rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
411  /*doc=*/nullptr, /*library_call=*/nullptr);
412  Block &prodBlock = prod.getRegion().front();
413  Block &consBlock = op.getRegion().front();
414  IRMapping mapper;
415  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
416  unsigned num = prodBlock.getNumArguments();
417  for (unsigned i = 0; i < num - 1; i++)
418  addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419  addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
420  addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421  // Clone bodies of the producer and consumer in new evaluation order.
422  auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
423  auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
424  Value last;
425  for (auto &op : prodBlock.without_terminator())
426  if (&op != acc) {
427  last = op.getResult(0);
428  rewriter.clone(op, mapper);
429  }
430  mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
431  mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
432  last = rewriter.clone(*acc, mapper)->getResult(0);
433  linalg::YieldOp::create(rewriter, loc, last);
434  // Force initial value on merged allocation for dense outputs.
435  // TODO: deal with non alloc tensor here one day
436  if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437  Value init = prod.getDpsInitOperand(0)
438  ->get()
439  .getDefiningOp<AllocTensorOp>()
440  .getCopy();
441  AllocTensorOp a =
442  op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443  rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444  }
445  // Replace consumer with fused operation. Old producer
446  // and consumer ops will be removed by DCE.
447  rewriter.replaceOp(op, fusedOp->getResults());
448  return success();
449  }
450 
451 private:
452  // Helper to add argument and record the mapping.
453  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
454  mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
455  }
456 };
457 
458 // Fuse a tensor cast into its producing operation. Note that a tensor.cast
459 // should really not be used to convert between sparse encodings. Since
460 // the pattern currently appears as a result of some prior rewriting,
461 // we make an attempt to repair very obvious cases.
462 // TODO: audit the pure tensor dialect rewriting rules
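// For example (editorial sketch): a cast that only changes the encoding, such
// as
//   %1 = tensor.cast %0 : tensor<8xf64, #SV> to tensor<8xf64>
// is repaired into the properly supported
//   %1 = sparse_tensor.convert %0 : tensor<8xf64, #SV> to tensor<8xf64>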
463 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
464 public:
465  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
466 
467  LogicalResult matchAndRewrite(tensor::CastOp op,
468  PatternRewriter &rewriter) const override {
469  Type srcType = op.getSource().getType();
470  Type dstType = op.getDest().getType();
471  // A nop cast simply folds away.
472  if (srcType == dstType) {
473  rewriter.replaceOp(op, op->getResults());
474  return success();
475  }
476  // See if a sparsity changing cast can be fused into producer.
477  if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
478  if (Operation *def = op.getSource().getDefiningOp()) {
479  if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
480  rewriter.modifyOpInPlace(def, [&]() {
481  def->getResult(0).setType(op->getResultTypes()[0]);
482  });
483  rewriter.replaceOp(op, def->getResult(0));
484  return success();
485  }
486  }
487  }
488  // Repair tensor casts with at least one sparse operand into the
489  // properly supported sparse_tensor.convert.
490  if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
491  rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
492  return success();
493  }
494  // Fail otherwise.
495  return failure();
496  }
497 };
498 
499 /// Rewrites a sequence of operations for sparse tensor selections into
500 /// semi-ring operations such that they can be compiled correctly by the
501 /// sparsifier. E.g., transforming the following sequence
502 ///
503 /// %sel = arith.select %cond, %sp1, %sp2
504 ///
505 /// to
506 ///
507 /// %sel = binary %sp1, %sp2:
508 /// both (%l, %r) {yield select %cond, %l, %r}
509 /// left (%l) {yield select %cond, %l, 0}
510 /// right (%r) {yield select %cond, 0, %r}
511 ///
512 /// TODO: We require the tensor used for extracting conditions to be dense
513 /// in order to sparsify the code. To support a sparse condition tensor, we
514 /// need a ternary operation.
515 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516 public:
517  using OpRewritePattern<GenericOp>::OpRewritePattern;
518  LogicalResult matchAndRewrite(GenericOp op,
519  PatternRewriter &rewriter) const override {
520  // Rejects non sparse kernels.
521  if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522  return failure();
523 
524  Location loc = op.getLoc();
525  SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
526  for (Operation &inst : *op.getBody()) {
527  // Matches pattern.
528  auto matched = isRewritablePattern(op, &inst);
529  if (!matched.has_value())
530  continue;
531 
532  rewriter.setInsertionPoint(&inst);
533  auto [c, t, f] = matched.value();
534  assert(t.getType() == f.getType());
535  auto selTp = t.getType();
536  auto c0 = constantZero(rewriter, loc, selTp);
537  auto binOp = sparse_tensor::BinaryOp::create(rewriter, loc, selTp, t, f);
538  // Initializes all the blocks.
539  rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540  {t.getLoc(), f.getLoc()});
541  rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542  rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543 
544  for (auto *r : binOp.getRegions()) {
545  Block *b = &r->front();
546  rewriter.setInsertionPointToStart(b);
547 
548  IRMapping irMap;
549  // Clones the cmp operations into the region to make the binary op
550  // admissible.
551  Value newC = c;
552  if (auto *def = c.getDefiningOp())
553  newC = rewriter.clone(*def, irMap)->getResult(0);
554 
555  irMap.map(c, newC);
556  if (r == &binOp.getLeftRegion()) {
557  irMap.map(t, b->getArgument(0));
558  irMap.map(f, c0);
559  } else if (r == &binOp.getRightRegion()) {
560  irMap.map(t, c0);
561  irMap.map(f, b->getArgument(0));
562  } else {
563  irMap.map(t, b->getArgument(0));
564  irMap.map(f, b->getArgument(1));
565  }
566  auto y = rewriter.clone(inst, irMap)->getResult(0);
567  sparse_tensor::YieldOp::create(rewriter, loc, y);
568  }
569 
570  // We successfully rewrote an operation. We cannot do the replacement here
571  // because it would invalidate the iterator of the current loop traversing
572  // the instructions.
573  semiRings.emplace_back(&inst, binOp);
574  }
575 
576  // Finalizes the replacement.
577  for (auto [sel, semi] : semiRings)
578  rewriter.replaceOp(sel, semi->getResults());
579 
580  return success(!semiRings.empty());
581  }
582 
583 private:
584  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585  isRewritablePattern(GenericOp op, Operation *v) {
586  auto sel = dyn_cast<arith::SelectOp>(v);
587  if (!sel)
588  return std::nullopt;
589 
590  auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591  auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
592  // TODO: For simplicity, we only handle cases where both the true and false
593  // values are loaded directly from the input tensor. We can probably admit
594  // more cases in theory.
595  if (!tVal || !fVal)
596  return std::nullopt;
597 
598  // Helper lambda to determine whether the value is loaded from a dense input
599  // or is a loop invariant.
600  auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601  if (auto bArg = dyn_cast<BlockArgument>(v);
602  bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603  return true;
604  // If the value is defined outside the loop, it is a loop invariant.
605  return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606  };
607 
608  // If the condition value is loaded directly from a dense tensor or is a
609  // loop invariant, we can sparsify the kernel.
610  auto cond = sel.getCondition();
611  if (isValFromDenseInputOrInvariant(cond))
612  return std::make_tuple(cond, tVal, fVal);
613 
614  Value cmpL, cmpR;
615  if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616  matchers::m_Any(&cmpR))) ||
617  matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618  matchers::m_Any(&cmpR)))) {
619  // TODO: we can do it recursively to check whether all the leaf values are
620  // loaded from dense tensors or are loop invariants.
621  if (isValFromDenseInputOrInvariant(cmpL) ||
622  isValFromDenseInputOrInvariant(cmpR))
623  return std::make_tuple(cond, tVal, fVal);
624  }
625 
626  return std::nullopt;
627  };
628 };
629 
630 /// Rewrites a sparse reduction that would not sparsify directly since
631 /// doing so would only iterate over the stored elements, ignoring the
632 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
633 /// (note that reductions like add/sub/or/xor can directly be sparsified
634 /// since the implicit zeros do not contribute to the final result).
635 /// Note that prod/and are still included since, even though they often
636 /// are nullified in sparse data, they may still occur for special
637 /// situations in which e.g. some rows in a sparse matrix are fully
638 /// dense. For min/max, including the implicit zeros is a much more
639 /// common situation.
640 ///
641 /// TODO: this essentially "densifies" the operation; we want to implement
642 /// this much more efficiently by performing the reduction over the
643 /// stored values, and feed in the zero once if there were *any*
644 /// implicit zeros as well; but for now, at least we provide
645 /// the functionality
646 ///
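/// For example (editorial sketch): for a "min" reduction x = min(x, a(i))
/// over a sparse vector, the reduction block is rewritten to use
///
///   %v = sparse_tensor.unary %a {present: yield %a} {absent: yield 0.0}
///   %x = sparse_tensor.reduce %v, %x, %identity { min }
///
/// so that the implicit zeros participate in the reduction as well.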
647 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
648 public:
649  using OpRewritePattern<GenericOp>::OpRewritePattern;
650 
651  LogicalResult matchAndRewrite(GenericOp op,
652  PatternRewriter &rewriter) const override {
653  // Reject non-reductions.
654  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655  op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656  return failure();
657  auto *inp = op.getDpsInputOperand(0);
658  auto *init = op.getDpsInitOperand(0);
659  if (!isSparseTensor(inp))
660  return failure();
661  // Look for direct x = x OP y for semi-ring ready reductions.
662  auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
663  .getOperand(0)
664  .getDefiningOp();
665  if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666  arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667  arith::MaxUIOp>(red))
668  return failure();
669  Value s0 = op.getBlock()->getArgument(0);
670  Value s1 = op.getBlock()->getArgument(1);
671  if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672  (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673  return failure();
674  // Identity.
675  Location loc = op.getLoc();
676  Value identity =
677  tensor::ExtractOp::create(rewriter, loc, init->get(), ValueRange());
678  // Unary {
679  // present -> value
680  // absent -> zero.
681  // }
682  Type rtp = s0.getType();
683  rewriter.setInsertionPointToStart(&op.getRegion().front());
684  auto semiring = sparse_tensor::UnaryOp::create(rewriter, loc, rtp, s0);
685  Block *present =
686  rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
687  rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
688  sparse_tensor::YieldOp::create(rewriter, loc, present->getArgument(0));
689  rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
690  rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
691  auto zero =
692  arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(rtp));
693  sparse_tensor::YieldOp::create(rewriter, loc, zero);
694  rewriter.setInsertionPointAfter(semiring);
695  // CustomReduce {
696  // x = x REDUC y, identity
697  // }
698  auto custom = sparse_tensor::ReduceOp::create(
699  rewriter, loc, rtp, semiring.getResult(), s1, identity);
700  Block *region =
701  rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
702  rewriter.setInsertionPointToStart(&custom.getRegion().front());
703  IRMapping irMap;
704  irMap.map(red->getOperand(0), region->getArgument(0));
705  irMap.map(red->getOperand(1), region->getArgument(1));
706  auto *cloned = rewriter.clone(*red, irMap);
707  sparse_tensor::YieldOp::create(rewriter, loc, cloned->getResult(0));
708  rewriter.setInsertionPointAfter(custom);
709  rewriter.replaceOp(red, custom.getResult());
710  return success();
711  }
712 };
713 
714 /// Sparse rewriting rule for the print operator. This operation is mainly used
715 /// for debugging and testing. As such, it lowers to the vector.print operation
716 /// which only requires very lightweight runtime support.
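/// For a CSR matrix, the emitted vector.print ops produce output roughly of
/// the form (editorial sketch, exact formatting depends on the runtime):
///
///   ---- Sparse Tensor ----
///   nse = 3
///   dim = ( 4, 8 )
///   lvl = ( 4, 8 )
///   pos[1] : ( 0, 1, 2, 3, 3 )
///   crd[1] : ( 2, 5, 7 )
///   values : ( 1, 2, 3 )
///   ----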
717 struct PrintRewriter : public OpRewritePattern<PrintOp> {
718 public:
719  using OpRewritePattern<PrintOp>::OpRewritePattern;
720  LogicalResult matchAndRewrite(PrintOp op,
721  PatternRewriter &rewriter) const override {
722  Location loc = op.getLoc();
723  auto tensor = op.getTensor();
724  auto stt = getSparseTensorType(tensor);
725  // Header with NSE.
726  auto nse = NumberOfEntriesOp::create(rewriter, loc, tensor);
727  vector::PrintOp::create(
728  rewriter, loc,
729  rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
730  vector::PrintOp::create(rewriter, loc, nse);
731  // Print run-time contents for dim/lvl sizes.
732  vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("dim = "));
733  printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
734  vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("lvl = "));
735  printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
736  // Use the "codegen" foreach loop construct to iterate over
737  // all typical sparse tensor components for printing.
738  foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
739  &stt](Type, FieldIndex,
740                                            SparseTensorFieldKind kind,
741                                            Level l, LevelType) {
742  switch (kind) {
743  case SparseTensorFieldKind::StorageSpec: {
744  break;
745  }
746  case SparseTensorFieldKind::PosMemRef: {
747  auto lvl = constantIndex(rewriter, loc, l);
748  vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("pos["));
749  vector::PrintOp::create(rewriter, loc, lvl,
750  vector::PrintPunctuation::NoPunctuation);
751  vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
752  auto pos = ToPositionsOp::create(rewriter, loc, tensor, l);
753  printContents(rewriter, loc, pos);
754  break;
755  }
756  case SparseTensorFieldKind::CrdMemRef: {
757  auto lvl = constantIndex(rewriter, loc, l);
758  vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("crd["));
759  vector::PrintOp::create(rewriter, loc, lvl,
760  vector::PrintPunctuation::NoPunctuation);
761  vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
762  Value crd = nullptr;
763  // For COO AoS storage, we want to print a single, linear view of
764  // the full coordinate storage at this level. For any other storage,
765  // we show the coordinate storage for every individual level.
766  if (stt.getAoSCOOStart() == l)
767  crd = ToCoordinatesBufferOp::create(rewriter, loc, tensor);
768  else
769  crd = ToCoordinatesOp::create(rewriter, loc, tensor, l);
770  printContents(rewriter, loc, crd);
771  break;
772  }
773  case SparseTensorFieldKind::ValMemRef: {
774  vector::PrintOp::create(rewriter, loc,
775  rewriter.getStringAttr("values : "));
776  auto val = ToValuesOp::create(rewriter, loc, tensor);
777  printContents(rewriter, loc, val);
778  break;
779  }
780  }
781  return true;
782  });
783  vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("----\n"));
784  rewriter.eraseOp(op);
785  return success();
786  }
787 
788 private:
789  // Helper to print contents of a single memref. For "push_back" vectors,
790  // we assume that the previous getters for pos/crd/val have added a
791  // slice-to-size view to make sure we just print the size and not the
792  // full capacity.
793  //
794  // Generates code to print (1-dim or higher):
795  // ( a0, a1, ... )
796  static void printContents(PatternRewriter &rewriter, Location loc,
797  Value vec) {
798  auto shape = cast<ShapedType>(vec.getType()).getShape();
799  SmallVector<Value> idxs;
800  printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
801  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
802  }
803 
804  // Helper to the helper.
805  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
806  Value vec, unsigned i, ArrayRef<int64_t> shape,
807  SmallVectorImpl<Value> &idxs) {
808  // Open bracket.
809  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
810  // Generate for loop.
811  auto zero = constantIndex(rewriter, loc, 0);
812  auto index = constantIndex(rewriter, loc, i);
813  auto size = memref::DimOp::create(rewriter, loc, vec, index);
814  auto step = constantIndex(rewriter, loc, 1);
815  auto forOp = scf::ForOp::create(rewriter, loc, zero, size, step);
816  idxs.push_back(forOp.getInductionVar());
817  rewriter.setInsertionPointToStart(forOp.getBody());
818  if (i < shape.size() - 1) {
819  // Enter deeper loop nest.
820  printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
821  } else {
822  // Actual contents printing.
823  auto val = memref::LoadOp::create(rewriter, loc, vec, idxs);
824  if (llvm::isa<ComplexType>(val.getType())) {
825  // Since the vector dialect does not support complex types in any op,
826  // we split those into (real, imag) pairs here.
827  Value real = complex::ReOp::create(rewriter, loc, val);
828  Value imag = complex::ImOp::create(rewriter, loc, val);
829  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
830  vector::PrintOp::create(rewriter, loc, real,
831  vector::PrintPunctuation::Comma);
832  vector::PrintOp::create(rewriter, loc, imag,
833  vector::PrintPunctuation::Close);
834  } else {
835  vector::PrintOp::create(rewriter, loc, val,
836  vector::PrintPunctuation::NoPunctuation);
837  }
838  // Terminating comma (except at end).
839  auto bound = arith::AddIOp::create(rewriter, loc, idxs.back(), step);
840  Value cond = arith::CmpIOp::create(rewriter, loc,
841  arith::CmpIPredicate::ne, bound, size);
842  scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, cond, /*else*/ false);
843  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
844  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Comma);
845  }
846  idxs.pop_back();
847  rewriter.setInsertionPointAfter(forOp);
848  // Close bracket.
849  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
850  }
851 
852  // Helper method to print run-time lvl/dim sizes.
853  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
854  unsigned size, bool isDim) {
855  // Open bracket.
856  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
857  // Print unrolled contents (dimop requires constant value).
858  for (unsigned i = 0; i < size; i++) {
859  auto idx = constantIndex(rewriter, loc, i);
860  Value val;
861  if (isDim)
862  val = tensor::DimOp::create(rewriter, loc, tensor, idx);
863  else
864  val = LvlOp::create(rewriter, loc, tensor, idx);
865  vector::PrintOp::create(rewriter, loc, val,
866  i != size - 1
867  ? vector::PrintPunctuation::Comma
868  : vector::PrintPunctuation::NoPunctuation);
869  }
870  // Close bracket and end of line.
871  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
872  vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
873  }
874 };
875 
876 /// Sparse rewriting rule for the sparse-to-sparse tensor.reshape operator.
877 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
878 public:
879  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
880 
881  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
882  PatternRewriter &rewriter) const override {
883  Location loc = op.getLoc();
884  Value srcTensor = op.getSource();
885  const auto srcTp = tryGetSparseTensorType(srcTensor);
886  const auto dstTp = tryGetSparseTensorType(op.getResult());
887  if (!srcTp || !dstTp)
888  return failure();
889 
890  if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
891  !dstTp->hasStaticDimShape())
892  return failure();
893 
894  SmallVector<Value> srcSizes;
895  sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
896  SmallVector<Value> dstSizes;
897  for (Dimension d : dstTp->getDimShape())
898  dstSizes.push_back(constantIndex(rewriter, loc, d));
899 
900  Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
901  // Only need an unordered COO buffer if input and output are not sorted
902  // in the same way.
903  Type bufferTp = getBufferType(
904  dstTp->withoutDimToLvl(),
905  !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
906  SmallVector<Value> dynSizes;
907  Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes,
908  Value(), nnz, Attribute())
909  .getResult();
910 
911  // Convert src coordinates to dst coordinates by first collapsing them to 1D
912  // and then expanding them to match the rank of the destination tensor.
913  // Implemented as follows:
914  // foreach srcCoords %srcTensor
915  // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
916  // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
917  // insert expandedCoords, %buffer
918  //
919  // followed by an optional
920  // %t = sparse_tensor.cast %tmp
921  // depending on whether the input/output are sorted in the same way.
922  const auto encSrc = srcTp->getEncoding();
923  ForeachOp foreachOp = ForeachOp::create(
924  rewriter, loc, srcTensor, buffer,
925  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
926  ValueRange reduc) {
927  const Dimension srcRank = srcTp->getDimRank();
928  SmallVector<Value> srcDcvs;
929  srcDcvs.reserve(srcRank);
930  for (Dimension d = 0; d < srcRank; d++) {
931  Level lvl = toLvl(encSrc, d);
932  srcDcvs.push_back(srcLcvs[lvl]);
933  }
934 
935  Value collapseSize = constantIndex(builder, loc, 1);
936  for (Dimension d = 0; d < srcRank; d++)
937  collapseSize =
938  arith::MulIOp::create(builder, loc, collapseSize, srcSizes[d]);
939  SmallVector<Value, 1> collapsedSizes = {collapseSize};
940 
941  ReassociationIndices collapseIdx;
942  for (Dimension i = 0; i < srcRank; i++)
943  collapseIdx.push_back(i);
944  SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
945  SmallVector<Value, 1> collapsedDcvs;
946  reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947  collapsedSizes, collapsedDcvs);
948 
949  ReassociationIndices expandIdx;
950  for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951  expandIdx.push_back(i);
952  SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
953  SmallVector<Value> dstDcvs;
954  reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955  dstSizes, dstDcvs);
956 
957  auto t =
958  tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
959  sparse_tensor::YieldOp::create(builder, loc, t);
960  });
961 
962  Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true);
963  if (bufferTp != *dstTp) {
964  auto dstRTT = dstTp->getRankedTensorType();
965  Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
966  DeallocTensorOp::create(rewriter, loc, t);
967  t = converted;
968  }
969  rewriter.replaceOp(op, t);
970  return success();
971  }
972 };
973 
974 /// Sparse rewriting rule for the sparse-to-sparse expand/collapse shape
975 /// operators.
975 template <typename ReshapeOp>
976 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
977 public:
979 
980  LogicalResult matchAndRewrite(ReshapeOp op,
981  PatternRewriter &rewriter) const override {
982  Location loc = op.getLoc();
983  Value srcTensor = op.getSrc();
984  const auto srcTp = getSparseTensorType(srcTensor);
985  const auto dstTp = getSparseTensorType(op.getResult());
986  if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987  return failure();
988 
989  // Generate code to represent the static dimension constants or compute
990  // the dynamic dimension values.
991  SmallVector<Value> srcSizes;
992  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
993  SmallVector<Value> dstSizes;
994  SmallVector<Value> dstDynSizes;
995  if (dstTp.hasStaticDimShape()) {
996  for (Dimension d : dstTp.getDimShape())
997  dstSizes.push_back(constantIndex(rewriter, loc, d));
998  } else {
999  ArrayRef<Size> dstShape = dstTp.getDimShape();
1000  genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
1001  op.getReassociationIndices());
1002  for (auto [idx, shape] : llvm::enumerate(dstShape)) {
1003  if (shape == ShapedType::kDynamic)
1004  dstDynSizes.push_back(dstSizes[idx]);
1005  }
1006  }
1007  Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
1008  // Only need an unordered COO buffer if input and output are not sorted
1009  // in the same way.
1010  Type bufferTp = getBufferType(
1011  dstTp.withoutDimToLvl(),
1012  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1013 
1014  Value buffer =
1015  AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(),
1016  /*sizeHint=*/nnz, Attribute())
1017  .getResult();
1018 
1019  // Implement the sparse2sparse reshape as follows:
1020  // foreach srcCoords %srcTensor
1021  // insert reshapeCvs(srcCoords), %buffer
1022  //
1023  // followed by an optional
1024  // %t = sparse_tensor.cast %tmp
1025  // depending on whether the input/output are sorted in the same way.
1026  const auto encSrc = srcTp.getEncoding();
1027  ForeachOp foreachOp = ForeachOp::create(
1028  rewriter, loc, srcTensor, buffer,
1029  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1030  ValueRange reduc) {
1031  const Dimension dimRank = srcTp.getDimRank();
1032  SmallVector<Value> srcDcvs;
1033  srcDcvs.reserve(dimRank);
1034  for (Dimension d = 0; d < dimRank; d++) {
1035  Level lvl = toLvl(encSrc, d);
1036  srcDcvs.push_back(srcLcvs[lvl]);
1037  }
1038  SmallVector<Value> dstDcvs;
1039  reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1040  srcDcvs, dstSizes, dstDcvs);
1041  auto t =
1042  tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
1043  sparse_tensor::YieldOp::create(builder, loc, t);
1044  });
1045 
1046  Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true);
1047  if (bufferTp != dstTp) {
1048  auto dstRTT = dstTp.getRankedTensorType();
1049  Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
1050  DeallocTensorOp::create(rewriter, loc, t);
1051  t = converted;
1052  }
1053  rewriter.replaceOp(op, t);
1054  return success();
1055  }
1056 };
1057 
1058 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1059 /// operator.
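/// For example (editorial sketch, reassociation details elided): a
/// sparse-to-dense expansion
///
///   %d = tensor.expand_shape %s ... : tensor<12xf64, #SV> into tensor<3x4xf64>
///
/// is unfused into a conversion followed by a purely dense expansion:
///
///   %t = sparse_tensor.convert %s : tensor<12xf64, #SV> to tensor<12xf64>
///   %d = tensor.expand_shape %t ... : tensor<12xf64> into tensor<3x4xf64>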
1060 template <typename ReshapeOp>
1061 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1062 public:
1063  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1064 
1065  LogicalResult matchAndRewrite(ReshapeOp op,
1066  PatternRewriter &rewriter) const override {
1067  Location loc = op->getLoc();
1068  auto encDst = getSparseTensorEncoding(op.getResult().getType());
1069  auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1070  // Since a pure dense expansion is very cheap (change of view), for
1071  // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1072  // conversion from the reshape operation itself.
1073  // All other cases are handled elsewhere.
1074  if (encDst && encSrc) {
1075  return failure();
1076  }
1077  if (encSrc) {
1078  auto rtp = getRankedTensorType(op.getSrc());
1079  auto denseTp =
1080  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1081  auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc());
1082  rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1083  return success();
1084  }
1085  if (encDst) {
1086  auto rtp = getRankedTensorType(op.getResult());
1087  auto denseTp =
1088  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1089  ReshapeOp reshape;
1090  if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1091  reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
1092  op.getReassociation(), op.getOutputShape(),
1093  op.getStaticOutputShape());
1094  } else {
1095  reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
1096  op.getReassociation());
1097  }
1098  Value convert = ConvertOp::create(rewriter, loc, rtp, reshape);
1099  rewriter.replaceOp(op, convert);
1100  return success();
1101  }
1102  return failure();
1103  }
1104 };
1105 
1106 // A trivial wrapper to help generate different operations for dense/sparse
1107 // tensors.
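// For example (editorial sketch): for a dense result type the constructor
// emits a bufferization.alloc_tensor followed by a linalg.fill with zero,
// whereas for a sparse result type only the alloc_tensor is emitted and
// finalize() terminates the insertion chain with a sparse_tensor.load that
// has its hasInserts flag set.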
1108 struct TensorLike {
1109  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1110  ValueRange sizes) {
1111  SmallVector<Value> dynSzs;
1112  getDynamicSizes(rtt, sizes, dynSzs);
1113 
1114  val = AllocTensorOp::create(builder, loc, rtt, dynSzs);
1115  if (!isSparse()) {
1116  Value c0 = constantZero(builder, loc, rtt.getElementType());
1117  val = linalg::FillOp::create(builder, loc, c0, val).getResult(0);
1118  }
1119  }
1120 
1121  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1122  val = tensor::InsertOp::create(builder, loc, v, val, crds);
1123  }
1124 
1125  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1126  if (isSparse())
1127  return LoadOp::create(builder, loc, val, true);
1128  return val;
1129  }
1130 
1131  bool isSparse() const {
1132  return getSparseTensorEncoding(val.getType()) != nullptr;
1133  }
1134 
1135  Value val;
1136 };
1137 
1138 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1139  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1140  LogicalResult matchAndRewrite(tensor::DimOp op,
1141  PatternRewriter &rewriter) const override {
1142  std::optional<int64_t> dim = op.getConstantIndex();
1143  auto stt = tryGetSparseTensorType(op.getSource());
1144  if (!dim || !stt || !stt->hasEncoding())
1145  return failure();
1146 
1147  if (stt->isPermutation()) {
1148  rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1149  toLvl(stt->getEncoding(), *dim));
1150  return success();
1151  }
1152 
1153  // Non-permutation dim2lvl/lvl2dim maps.
1154  // Compute as follows:
1155  // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1156  // Note that this is not the most efficient way (but a more general one) to
1157  // do the lvl-to-dim translation; e.g., for BSR, the dimension size can be
1158  // computed simply by lvl_size * block_size.
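  // Worked example (editorial sketch): for a 2x2-blocked BSR encoding with
  // lvlToDim (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3), dimension 0 is
  // recovered as (lvl_size(0) - 1) * 2 + (2 - 1) + 1 = lvl_size(0) * 2.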
1159  Location loc = op.getLoc();
1160  SmallVector<Value> maxLvlCrds;
1161  for (Level l = 0; l < stt->getLvlRank(); l++) {
1162  Value lvlSz = LvlOp::create(rewriter, loc, op.getSource(), l);
1163  Value maxLvlCrd = arith::SubIOp::create(
1164  rewriter, loc, lvlSz,
1165  constantOne(rewriter, loc, rewriter.getIndexType()));
1166  maxLvlCrds.push_back(maxLvlCrd);
1167  }
1168 
1169  AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170  Value maxDimCrd = affine::AffineApplyOp::create(
1171  rewriter, op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172  maxLvlCrds);
1173 
1174  Value dimSz = arith::AddIOp::create(
1175  rewriter, loc, maxDimCrd,
1176  constantOne(rewriter, loc, rewriter.getIndexType()));
1177  rewriter.replaceOp(op, dimSz);
1178  return success();
1179  }
1180 };
1181 
1182 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1183  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
1184  LogicalResult matchAndRewrite(ConcatenateOp op,
1185  PatternRewriter &rewriter) const override {
1186  if (op.needsExtraSort())
1187  op.emitError("ConcatenateOp not staged");
1188 
1189  const Location loc = op.getLoc();
1190  const auto dstTp = getSparseTensorType(op);
1191  const Dimension conDim = op.getDimension();
1192  SmallVector<Value> sizes;
1193  concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1194 
1195  // %t = concatenate %s1, %s2, %s3 {dim = 1}
1196  // ==>
1197  // if (isSparseDst)
1198  // if (allDense)
1199  // %tmp = bufferization.alloc_tensor dstTp
1200  // else
1201  // %tmp = bufferization.alloc_tensor : unordered COO
1202  // else
1203  // %tmp = memref.alloc : dense tensor
1204  // foreach in %s1 : insert d0, d1, %tmp
1205  // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1206  // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1207 
1208  TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1209  Value offset = constantIndex(rewriter, loc, 0);
1210  Value iterArg = dstBuf.val;
1211 
1212  ForeachOp foreachOp;
1213  for (Value input : op.getInputs()) {
1214  // Builds a foreach op for each input tensor to append new values into the
1215  // output tensor.
1216  foreachOp = ForeachOp::create(
1217  rewriter, loc, input, iterArg,
1218  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1219  ValueRange reduc) {
1220  SmallVector<Value> offDimCrd(dcvs);
1221  offDimCrd[conDim] =
1222  arith::AddIOp::create(builder, loc, offDimCrd[conDim], offset);
1223 
1224  // Enters foreach, updates the SSA chain.
1225  dstBuf.val = reduc.front();
1226  if (!dstTp.isAllDense()) {
1227  Value cond = genIsNonzero(builder, loc, v);
1228  auto ifOp =
1229  scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
1230  /*else*/ true);
1231  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1232  scf::YieldOp::create(builder, loc, dstBuf.val);
1233 
1234  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1235  dstBuf.insert(builder, loc, v, offDimCrd);
1236  scf::YieldOp::create(builder, loc, dstBuf.val);
1237 
1238  // Exits the ifOp, update the sparse tensor SSA value.
1239  builder.setInsertionPointAfter(ifOp);
1240  dstBuf.val = ifOp.getResult(0);
1241  } else {
1242  dstBuf.insert(builder, loc, v, offDimCrd);
1243  }
1244  sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
1245  });
1246  // Accumulates the offset. Note that only static-shaped inputs are allowed
1247  // by the concatenate op verifier, which saves us from computing the offset
1248  // dynamically.
1249  const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1250  assert(ShapedType::isStatic(sz));
1251  offset = arith::AddIOp::create(rewriter, loc, offset,
1252  constantIndex(rewriter, loc, sz));
1253  iterArg = foreachOp.getResult(0);
1254  dstBuf.val = iterArg;
1255  }
1256 
1257  dstBuf.val = iterArg;
1258  Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1259  rewriter.replaceOp(op, ret);
1260  return success();
1261  }
1262 };
1263 
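/// Rewrites a (staged) sparse_tensor.convert into a foreach-driven insertion
/// loop: iterate over the source, optionally guard against zeros for dense or
/// non-constant sources, and insert every entry into a freshly allocated
/// destination tensor (editorial summary of the pattern below).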
1264 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1265  using OpRewritePattern<ConvertOp>::OpRewritePattern;
1266  LogicalResult matchAndRewrite(ConvertOp op,
1267  PatternRewriter &rewriter) const override {
1268  if (op.needsExtraSort())
1269  return op.emitError("ConvertOp not staged.");
1270 
1271  // TODO: Maybe we want a different operation for this too.
1272  auto encDst = getSparseTensorEncoding(op.getType());
1273  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1274  if (encDst && encSrc && !encSrc.isSlice() &&
1275  encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1276  // Trivial tensor conversion and simple element type conversion are handled
1277  // in codegen.
1278  return failure();
1279  }
1280 
1281  Location loc = op.getLoc();
1282  Value src = op.getSource();
1283 
1284  SparseTensorType srcStt = getSparseTensorType(op.getSource());
1285  SparseTensorType dstStt = getSparseTensorType(op.getDest());
1286 
1287  bool fromSparseConst = false;
1288  if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1289  if (isa<SparseElementsAttr>(constOp.getValue()))
1290  fromSparseConst = true;
1291 
1292  const AffineMapAttr foreachOrder =
1293  (!dstStt.isIdentity() && fromSparseConst)
1294          ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1295          : nullptr;
1296 
1297  bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1298 
1299  SmallVector<Value> sizes;
1300  sizesFromSrc(rewriter, sizes, loc, src);
1301  ValueRange vs;
1302  TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1303 
1304  auto foreachOp = ForeachOp::create(
1305  rewriter, loc, src, dstBuf.val, foreachOrder,
1306  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1307  ValueRange reduc) {
1308  // Enters the loop, update the SSA value for insertion chain.
1309  dstBuf.val = reduc.front();
1310  if (!skipZeroCheck) {
1311  Value cond = genIsNonzero(builder, loc, v);
1312  auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
1313  /*else*/ true);
1314  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1315  scf::YieldOp::create(builder, loc, dstBuf.val);
1316 
1317  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1318  dstBuf.insert(builder, loc, v, dcvs);
1319  scf::YieldOp::create(builder, loc, dstBuf.val);
1320 
1321  // Exits the ifOp, update the sparse tensor SSA value.
1322  builder.setInsertionPointAfter(ifOp);
1323  dstBuf.val = ifOp.getResult(0);
1324  } else {
1325  dstBuf.insert(builder, loc, v, dcvs);
1326  }
1327  sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
1328  });
1329 
1330  rewriter.setInsertionPointAfter(foreachOp);
1331 
1332  // Exits the for loop, links the SSA chain.
1333  dstBuf.val = foreachOp.getResult(0);
1334 
1335  Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1336  rewriter.replaceOp(op, ret);
1337  return success();
1338  }
1339 };
1340 
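/// Lowers sparse_tensor.crd_translate by expanding the dim2lvl or lvl2dim
/// affine map into one affine.apply per result (editorial summary of the
/// pattern below).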
1341 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1342  using OpRewritePattern<CrdTranslateOp>::OpRewritePattern;
1343  LogicalResult matchAndRewrite(CrdTranslateOp op,
1344  PatternRewriter &rewriter) const override {
1345  AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1346  ? op.getEncoder().getDimToLvl()
1347  : op.getEncoder().getLvlToDim();
1348 
1349  SmallVector<Value> outCrds;
1350  for (AffineExpr result : map.getResults()) {
1351  // TODO: we should probably expand the affine map to IR using our own
1352  // rules, since affine.apply assumes signed values, while the coordinates
1353  // we provide must always be signless.
1354  Value trans = affine::AffineApplyOp::create(
1355  rewriter, op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1356  op.getInCrds());
1357  outCrds.push_back(trans);
1358  }
1359  rewriter.replaceOp(op, outCrds);
1360  return success();
1361  }
1362 };
1363 
1364 /// Sparse rewriting rule for the foreach operator.
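/// Conceptually (editorial sketch): a foreach over a CSR matrix becomes two
/// nested loops generated by the LoopEmitter over the positions/coordinates
/// storage, with the foreach body inlined in the innermost loop and its
/// yielded values threaded through as reduction arguments.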
1365 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1366 public:
1367  using OpRewritePattern<ForeachOp>::OpRewritePattern;
1368 
1369  LogicalResult matchAndRewrite(ForeachOp op,
1370  PatternRewriter &rewriter) const override {
1371 
1372  auto loc = op.getLoc();
1373  Value input = op.getTensor();
1374  SmallVector<Value> reduc = op.getInitArgs();
1375  const auto stt = getSparseTensorType(input);
1376  const Level lvlRank = stt.getLvlRank();
1377 
1378  // Special-case: foreach over a sparse constant uses its own rewriting
1379  // rule.
1380  if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1381  if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1382  return genForeachOnSparseConstant(op, rewriter, attr);
1383  }
1384  }
1385 
1386  // Otherwise, use loop emitter to generate loops.
1387  const auto enc = stt.getEncoding();
1388 
1389  // 1. Generates loop for the sparse input.
1390  LoopEmitter loopEmitter(
1391  ValueRange{input},
1392  StringAttr::get(getContext(), ForeachOp::getOperationName()));
1393  loopEmitter.initializeLoopEmit(rewriter, loc);
1394  for (Level l = 0; l < lvlRank; l++) {
1395  // TODO: provide a utility function for loop sequences that only contain
1396  // one for loop?
1397  const SmallVector<TensorLevel, 1> tidLvls{
1398  loopEmitter.makeTensorLevel(0, l)};
1399  loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1400  // Note that reduc will be taken care of by loop emitter and get updated
1401  // in place.
1402  loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1403  reduc);
1404  }
1405 
1406  SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1407  if (op.getOrder()) {
1408  // TODO: Support it so that we can do direct conversion from CSR->BSR.
1409  llvm_unreachable(
1410  "Level order not yet implemented on non-constant input tensors.");
1411  }
1412 
1413  Value vals = loopEmitter.getValBuffer()[0];
1414  SmallVector<Value> pos = loopEmitter.getValPosits(0);
1415  // Loads the value from sparse tensor using position-index;
1416  // loads the value from dense tensor using coords.
1417  Value val = enc ? memref::LoadOp::create(rewriter, loc, vals, pos)
1418  : memref::LoadOp::create(rewriter, loc, vals, lcvs);
1419 
1420  // 2. Inline the block in the foreach operator.
1421  Block *srcBlock = op.getBody();
1422 
1423  // Remap coordinates.
1424  SmallVector<Value> args =
1425  enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1426 
1427  // Remap value.
1428  args.push_back(val);
1429  // Remap reduction variables.
1430  args.append(reduc);
1431 
1432  // Remove sparse_tensor.yield.
1433  SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1434  rewriter.eraseOp(srcBlock->getTerminator());
1435 
1436  Operation &last = rewriter.getBlock()->back();
1437  if (llvm::isa<scf::YieldOp>(last)) {
1438  // Because `scf.for` inserts an implicit yield op when there is no
1439  // reduction variable upon creation, we reset the insertion point such
1440  // that the block is inlined right before the yield op.
1441  rewriter.setInsertionPoint(&last);
1442  }
1443 
1444  rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1445  rewriter.getInsertionPoint(), args);
1446  rewriter.setInsertionPointToEnd(rewriter.getBlock());
1447  for (Level l = 0; l < lvlRank; l++) {
1448  // Link the reduction chain. Note that the loop emitter updates reducValue
1449  // in place.
1450  loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1451  loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1452  }
1453 
1454  // Replace the foreach operator with the value returned by the outermost
1455  // for loop.
1456  rewriter.replaceOp(op, reducValue);
1457  return success();
1458  }
1459 };
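// Illustration (schematic only; the buffer names and the 1-D compressed
// encoding are assumed for the example, not taken from this file): for a
// foreach over a sparse vector, the loop emitter produces roughly
//   %pos = sparse_tensor.positions %t {level = 0 : index}
//   %crd = sparse_tensor.coordinates %t {level = 0 : index}
//   %val = sparse_tensor.values %t
//   scf.for %p = %lo to %hi step %c1 iter_args(%r = %init) {
//     %c = memref.load %crd[%p]   // level coordinate
//     %v = memref.load %val[%p]   // stored value
//     ... inlined foreach body, yielding the updated reduction ...
//   }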
1460 
1461 /// Sparse rewriting rule for the new operator.
1462 struct NewRewriter : public OpRewritePattern<NewOp> {
1463  using OpRewritePattern::OpRewritePattern;
1464  LogicalResult matchAndRewrite(NewOp op,
1465  PatternRewriter &rewriter) const override {
1466  Location loc = op.getLoc();
1467  auto stt = getSparseTensorType(op.getResult());
1468  if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1469  return failure();
1470 
1471  // Implement the NewOp as follows:
1472  // %orderedCoo = sparse_tensor.new %filename
1473  // %t = sparse_tensor.convert %orderedCoo
1474  // with enveloping reinterpret_map ops for non-permutations.
1475  RankedTensorType dstTp = stt.getRankedTensorType();
1476  RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1477  Value cooTensor = NewOp::create(rewriter, loc, cooTp, op.getSource());
1478  Value convert = cooTensor;
1479  auto enc = stt.getEncoding();
1480  if (!stt.isPermutation()) { // demap coo, demap dstTp
1481  auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1482  convert = ReinterpretMapOp::create(rewriter, loc, coo, convert);
1483  dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1484  }
1485  convert = ConvertOp::create(rewriter, loc, dstTp, convert);
1486  if (!stt.isPermutation()) // remap to original enc
1487  convert = ReinterpretMapOp::create(rewriter, loc, enc, convert);
1488  rewriter.replaceOp(op, convert);
1489 
1490  // Release the temporary ordered COO tensor.
1491  rewriter.setInsertionPointAfterValue(convert);
1492  DeallocTensorOp::create(rewriter, loc, cooTensor);
1493 
1494  return success();
1495  }
1496 };
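// Illustration (assumed #SortedCOO/#CSR encodings, for exposition only):
//   %coo = sparse_tensor.new %src : !Filename to tensor<?x?xf64, #SortedCOO>
//   %csr = sparse_tensor.convert %coo
//            : tensor<?x?xf64, #SortedCOO> to tensor<?x?xf64, #CSR>
//   bufferization.dealloc_tensor %coo : tensor<?x?xf64, #SortedCOO>
// For non-permutation dimToLvl maps, the convert is additionally wrapped in a
// pair of sparse_tensor.reinterpret_map ops, as in the code above.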
1497 
1498 /// Sparse rewriting rule for the out operator.
1499 struct OutRewriter : public OpRewritePattern<OutOp> {
1500  using OpRewritePattern::OpRewritePattern;
1501  LogicalResult matchAndRewrite(OutOp op,
1502  PatternRewriter &rewriter) const override {
1503  Location loc = op.getLoc();
1504  // Calculate NNZ.
1505  Value src = op.getTensor();
1506  Value nnz = NumberOfEntriesOp::create(rewriter, loc, src);
1507 
1508  // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1509  const auto srcTp = getSparseTensorType(src);
1510  const Dimension dimRank = srcTp.getDimRank();
1511  Type indexTp = rewriter.getIndexType();
1512  Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1513 
1514  // Generate code to calculate dimension size values and store the values to
1515  // the buffer.
1516  SmallVector<Value> dims;
1517  sizesForTensor(rewriter, dims, loc, srcTp, src);
1518  for (Dimension d = 0; d < dimRank; d++) {
1519  memref::StoreOp::create(rewriter, loc, dims[d], dimSizes,
1520  constantIndex(rewriter, loc, d));
1521  }
1522 
1523  // Create a sparse tensor writer and output metadata.
1524  Type opaqueTp = getOpaquePointerType(rewriter);
1525  Value writer =
1526  createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1527  {op.getDest()}, EmitCInterface::Off)
1528  .getResult(0);
1529  Value rankValue = constantIndex(rewriter, loc, dimRank);
1530  createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1531  {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1532 
1533  Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1534  Type eltTp = srcTp.getElementType();
1535  SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1536  primaryTypeFunctionSuffix(eltTp)};
1537  Value value = genAllocaScalar(rewriter, loc, eltTp);
1538  ModuleOp module = op->getParentOfType<ModuleOp>();
1539 
1540  // For each element in the source tensor, output the element.
1541  ForeachOp::create(
1542  rewriter, loc, src, ValueRange(),
1543  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1544  ValueRange reduc) {
1545  for (Dimension d = 0; d < dimRank; d++) {
1546  memref::StoreOp::create(rewriter, loc, dcvs[d], dimCoords,
1547  constantIndex(builder, loc, d));
1548  }
1549  memref::StoreOp::create(rewriter, loc, v, value);
1550  SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1551  FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1552  EmitCInterface::On);
1553  func::CallOp::create(builder, loc, TypeRange(), fn, operands);
1554  sparse_tensor::YieldOp::create(builder, loc);
1555  });
1556 
1557  // Release the writer.
1558  createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1559  EmitCInterface::Off);
1560 
1561  rewriter.eraseOp(op);
1562  return success();
1563  }
1564 };
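// Illustration (schematic; only the runtime entry points named above are
// real, the SSA names are placeholders): the out op becomes a sequence of
// runtime-library calls of the form
//   %w = call @createSparseTensorWriter(%dest)
//   call @outSparseTensorWriterMetaData(%w, %rank, %nnz, %dimSizes)
//   // one call per stored entry, emitted via the foreach above:
//   call @outSparseTensorWriterNext{suffix}(%w, %rank, %dimCoords, %value)
//   call @delSparseTensorWriter(%w)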
1565 
1566 } // namespace
1567 
1568 //===---------------------------------------------------------------------===//
1569 // Methods that add patterns described in this file to a pattern list.
1570 //===---------------------------------------------------------------------===//
1571 
1572 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1573  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1574  FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1575  GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1576  patterns.getContext());
1577 }
1578 
1579 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1580  bool enableRT,
1581  bool enableConvert) {
1582  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1583  ReshapeRewriter<tensor::CollapseShapeOp>,
1584  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1585  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1586  SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1587  patterns.getContext());
1588 
1589  if (enableConvert)
1590  patterns.add<DirectConvertRewriter>(patterns.getContext());
1591  if (!enableRT)
1592  patterns.add<NewRewriter>(patterns.getContext());
1593 }
1594 
1595 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1596  // Run CrdTranslateRewriter later in the pipeline so that the operation can
1597  // be folded before lowering to affine.apply.
1598  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1599 }
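// Usage sketch (hypothetical, not part of this file): a pass would typically
// collect one of these pattern sets and run it with the greedy rewrite
// driver, roughly along the lines of
//
//   RewritePatternSet patterns(&getContext());
//   populateLowerSparseOpsToForeachPatterns(patterns, /*enableRT=*/false,
//                                           /*enableConvert=*/true);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();
//
// The actual wiring lives in the SparseTensor pass definitions.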