1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenUtils.h"
14 #include "Utils/LoopEmitter.h"
15 
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/Support/LLVM.h"
32 
33 using namespace mlir;
34 using namespace mlir::bufferization;
35 using namespace mlir::linalg;
36 using namespace mlir::sparse_tensor;
37 
38 //===---------------------------------------------------------------------===//
39 // Helper methods for the actual rewriting rules.
40 //===---------------------------------------------------------------------===//
41 
42 // Helper method to match any typed zero.
43 static bool isZeroValue(Value val) {
44  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45 }
46 
47 // Helper to detect a sparse tensor type operand.
48 static bool isSparseTensor(Value v) {
49  auto enc = getSparseTensorEncoding(v.getType());
50  return enc && !llvm::all_of(enc.getLvlTypes(),
51  [](auto lt) { return lt == LevelFormat::Dense; });
52 }
53 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54 
55 // Helper method to find zero/uninitialized tensor materialization.
56 static bool isMaterializing(OpOperand *op, bool isZero) {
57  Value val = op->get();
58  // Check allocation, with zero alloc when required.
59  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60  Value copy = alloc.getCopy();
61  if (isZero)
62  return copy && isZeroValue(copy);
63  return !copy;
64  }
65  // Check for empty tensor materialization.
66  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67  return !isZero;
68  // Last resort for zero alloc: the whole value is zero.
69  return isZero && isZeroValue(val);
70 }
71 
72 // Helper to detect sampling operation.
73 static bool isSampling(GenericOp op) {
74  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77  // Both scalar input arguments used exactly once.
78  Value s1 = op.getBlock()->getArgument(0);
79  Value s2 = op.getBlock()->getArgument(1);
80  return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81  (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82  }
83  }
84  return false;
85 }
86 
87 // Helper to detect chain of multiplications that do not involve x.
88 static bool isMulChain(Value val, Value x) {
89  if (auto arg = dyn_cast<BlockArgument>(val))
90  return arg != x;
91  if (auto *def = val.getDefiningOp()) {
92  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93  return isMulChain(def->getOperand(0), x) &&
94  isMulChain(def->getOperand(1), x);
95  }
96  return false;
97 }
98 
99 // Helper to detect x = x + <multiplications>.
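// E.g., x = x + a * b * c matches (the summand is a pure chain of
// multiplications not involving x), while x = x + (a + b) does not.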
100 static bool isSumOfMul(GenericOp op) {
101  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104  Value x = op.getBlock()->getArguments().back();
105  return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106  (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107  }
108  }
109  return false;
110 }
111 
112 // Helper to detect direct yield of a zero value.
113 static bool isZeroYield(GenericOp op) {
114  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116  if (arg.getOwner()->getParentOp() == op) {
117  return isZeroValue(op->getOperand(arg.getArgNumber()));
118  }
119  }
120  return isZeroValue(yieldOp.getOperand(0));
121 }
122 
123 /// Populates the given sizes array from the type (for static sizes) and from
124 /// the tensor (for dynamic sizes).
125 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126  Location loc, ShapedType stp, Value tensor) {
127  for (const auto &d : enumerate(stp.getShape())) {
128  Value dim;
129  if (d.value() == ShapedType::kDynamic)
130  dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131  else
132  dim = constantIndex(builder, loc, d.value());
133  sizes.push_back(dim);
134  }
135 }
136 
137 static RankedTensorType getBufferType(const SparseTensorType &stt,
138  bool needTmpCOO) {
139  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140  : stt.getRankedTensorType();
141 }
142 
143 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145 /// sizes to dynSizes.
146 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147  SmallVectorImpl<Value> &dynSizes) {
148  for (const auto &d : enumerate(tp.getShape())) {
149  if (d.value() == ShapedType::kDynamic)
150  dynSizes.push_back(sizes[d.index()]);
151  }
152 }
153 
154 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155  RewriterBase &rewriter,
156  SparseElementsAttr attr) {
157  auto loc = op.getLoc();
158  SmallVector<Value> reduc = op.getInitArgs();
159 
160  // Foreach on constant.
161  foreachInSparseConstant(
162  rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163  [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164  SmallVector<Value> args;
165  args.append(cvs.begin(), cvs.end());
166  args.push_back(v);
167  args.append(reduc);
168  // Clones the foreach op to get a copy of the loop body.
169  auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170  assert(args.size() == cloned.getBody()->getNumArguments());
171  Operation *yield = cloned.getBody()->getTerminator();
172  rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173  // clean up
174  rewriter.eraseOp(cloned);
175  reduc = yield->getOperands();
176  rewriter.eraseOp(yield);
177  });
178 
179  rewriter.replaceOp(op, reduc);
180  return success();
181 }
182 
183 /// Populates the given sizes array for concatenation from types (for static
184 /// sizes) and from the source tensors (for dynamic sizes).
185 static void concatSizesFromInputs(OpBuilder &builder,
186  SmallVectorImpl<Value> &sizes, Location loc,
187  ShapedType dstTp, ValueRange srcs,
188  unsigned dim) {
189  auto dstShape = dstTp.getShape();
190  sizesFromSrc(builder, sizes, loc, srcs[0]);
191 
192  // Sum up on the `dim` if the dimension is dynamic.
193  if (dstShape[dim] != ShapedType::kDynamic) {
194  // Faithfully take the static size.
195  sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196  } else {
197  // Else, compute the shape dynamically.
198  for (const auto &src : srcs.drop_front()) {
199  Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200  // Sum up all the sizes.
201  sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202  }
203  }
204 }
205 
206 //===---------------------------------------------------------------------===//
207 // The actual sparse tensor rewriting rules.
208 //===---------------------------------------------------------------------===//
209 
210 namespace {
211 
212 /// TODO: move it to tensor dialect instead.
213 ///
214 /// Fold `tensor.concat` and `tensor.extract_slice`
215 ///
216 /// %concat = tensor.concat dim(2) %t0, %t1
217 /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218 /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220 /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222 ///
223 /// Becomes
224 ///
225 /// %extract0, %extract1 = %t0, %t1
226 struct FuseExtractSliceWithConcat
227  : public OpRewritePattern<tensor::ExtractSliceOp> {
228  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229 
230  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231  PatternRewriter &rewriter) const override {
232  auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233  if (!concatOp)
234  return failure();
235 
236  Location loc = extractOp.getLoc();
237  int64_t dim = concatOp.getDim();
238  int64_t rank = extractOp.getResultType().getRank();
239 
240  SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241  SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242 
243  // Compute the partial sums for the slice offsets.
244  AffineExpr sum = rewriter.getAffineDimExpr(0);
245  SmallVector<AffineExpr> partialSums = {sum};
246  SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247  for (auto [idx, input] :
248  llvm::enumerate(concatOp.getInputs().drop_back())) {
249  sum = sum + rewriter.getAffineDimExpr(idx + 1);
250  partialSums.push_back(sum);
251  offsetStrides.push_back(
252  rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253  }
254  auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255  partialSums, rewriter.getContext());
256  SmallVector<OpFoldResult> dimOffsets =
257  affine::makeComposedFoldedMultiResultAffineApply(
258  rewriter, loc, partialSumMap, offsetStrides);
259 
260  auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261  for (auto [l, r] : llvm::zip(lhs, rhs)) {
262  std::optional<int64_t> staticVal = getConstantIntValue(l);
263  if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264  return false;
265  }
266  return lhs.size() == rhs.size();
267  };
268 
269  for (auto [i, input, offset] :
270  llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271  SmallVector<OpFoldResult> srcSizes =
272  tensor::getMixedSizes(rewriter, loc, input);
273  srcOffsets[dim] = offset;
274 
275  SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276  SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277  SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278 
279  if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280  allEqual(srcStrides, dstStrides)) {
281  Value operand = concatOp.getOperand(i);
282  if (operand.getType() == extractOp.getResultType())
283  rewriter.replaceOp(extractOp, operand);
284  break;
285  }
286  }
287 
288  return success();
289  }
290 };
291 
292 /// Rewriting rule that converts direct yield of zero with initial allocation.
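///
/// An illustrative sketch (the tensor type and #CSR encoding are hypothetical):
/// a kernel that merely yields a zero value into a freshly materialized tensor,
///
///   %0 = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
///   %1 = linalg.generic ... outs(%0 : tensor<4x4xf64, #CSR>) {
///     ^bb0(%out: f64):
///       linalg.yield %cst0 : f64
///   } -> tensor<4x4xf64, #CSR>
///
/// is replaced by the materialization itself for sparse outputs, or by a
/// constant zero tensor for statically shaped dense outputs.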
293 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
294 public:
295  using OpRewritePattern<GenericOp>::OpRewritePattern;
296 
297  LogicalResult matchAndRewrite(GenericOp op,
298  PatternRewriter &rewriter) const override {
299  if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
300  !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
301  !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
302  return failure();
303  auto outputType = getRankedTensorType(op.getResult(0));
304  // Yielding zero on newly materialized sparse tensor can be
305  // optimized directly (regardless of dynamic or static size).
306  if (getSparseTensorEncoding(outputType)) {
307  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
308  return success();
309  }
310  // Use static zero value directly instead of materialization.
311  if (!outputType.hasStaticShape())
312  return failure();
313  Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
314  rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
315  rewriter.eraseOp(def);
316  return success();
317  }
318 };
319 
320 /// Rewriting rule that converts two kernels:
321 ///
322 /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
323 /// X(i,j) = S(i,j) * T(i,j)
324 ///
325 /// into a single kernel, using distributive law:
326 ///
327 /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
328 ///
329 /// This kind of fusion (merging two ops into one but using arithmetic
330 /// equalities that may not hold for floating-point computations) would
331 /// be undesirable in the dense case, since we distribute the multiplication
332 /// into the reduction loop. However, for sparse sampling tensor S, such
333 /// a fusion may actually reduce the asymptotic complexity of the kernel,
334 /// since intermediate results may be nullified.
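///
/// For example (illustrative), with a sparse sampling matrix S this turns an
/// SDDMM-style pair
///
///   T(i,j) = SUM(k, A(i,k) * B(k,j))
///   X(i,j) = S(i,j) * T(i,j)
///
/// into the single kernel
///
///   X(i,j) = SUM(k, S(i,j) * A(i,k) * B(k,j))
///
/// so the multiplications need only be evaluated where S has stored entries.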
335 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
336 public:
337  using OpRewritePattern<GenericOp>::OpRewritePattern;
338 
339  LogicalResult matchAndRewrite(GenericOp op,
340  PatternRewriter &rewriter) const override {
341  // Check consumer.
342  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
343  op.getNumResults() != 1 ||
344  op.getNumParallelLoops() != op.getNumLoops() ||
345  !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
346  !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
347  !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
348  return failure();
349  // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
350  // operand can be sparse or dense, since the point of this rewriting rule
351  // is detecting a situation in which *more* sparsity is introduced into
352  // a computation, be it already sparse or still dense.
353  unsigned other = 0;
354  if (isSparseTensor(op.getDpsInputOperand(0)))
355  other = 1;
356  else if (!isSparseTensor(op.getDpsInputOperand(1)))
357  return failure();
358  // Check producer.
359  auto prod = dyn_cast_or_null<GenericOp>(
360  op.getDpsInputOperand(other)->get().getDefiningOp());
361  if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
362  !prod.getResult(0).hasOneUse())
363  return failure();
364  // Sampling consumer and sum of multiplication chain producer.
365  if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
366  !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
367  !isSampling(op) || !isSumOfMul(prod))
368  return failure();
369  // Modify operand structure of producer and consumer.
370  Location loc = prod.getLoc();
371  SmallVector<Value> inputOps = prod.getInputs();
372  SmallVector<Value> outputOps = op.getOutputs();
373  SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
374  inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
375  fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
376  // Fuse producer and consumer into a new generic op.
377  auto fusedOp = rewriter.create<GenericOp>(
378  loc, op.getResult(0).getType(), inputOps, outputOps,
379  rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
380  /*doc=*/nullptr, /*library_call=*/nullptr);
381  Block &prodBlock = prod.getRegion().front();
382  Block &consBlock = op.getRegion().front();
383  IRMapping mapper;
384  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
385  unsigned num = prodBlock.getNumArguments();
386  for (unsigned i = 0; i < num - 1; i++)
387  addArg(mapper, fusedBlock, prodBlock.getArgument(i));
388  addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
389  addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
390  // Clone bodies of the producer and consumer in new evaluation order.
391  auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
392  auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
393  Value last;
394  for (auto &op : prodBlock.without_terminator())
395  if (&op != acc) {
396  last = op.getResult(0);
397  rewriter.clone(op, mapper);
398  }
399  mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
400  mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
401  last = rewriter.clone(*acc, mapper)->getResult(0);
402  rewriter.create<linalg::YieldOp>(loc, last);
403  // Force initial value on merged allocation for dense outputs.
404  // TODO: deal with non alloc tensor here one day
405  if (!getSparseTensorEncoding(op.getResult(0).getType())) {
406  Value init = prod.getDpsInitOperand(0)
407  ->get()
408  .getDefiningOp<AllocTensorOp>()
409  .getCopy();
410  AllocTensorOp a =
411  op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
412  rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
413  }
414  // Replace consumer with fused operation. Old producer
415  // and consumer ops will be removed by DCE.
416  rewriter.replaceOp(op, fusedOp->getResults());
417  return success();
418  }
419 
420 private:
421  // Helper to add argument and record the mapping.
422  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
423  mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
424  }
425 };
426 
427 // Fuse a tensor cast into its producing operation. Note that a tensor.cast
428 // should really not be used to convert between sparse encodings. Since
429 // the pattern currently appears as a result of some prior rewriting,
430 // we make an attempt to repair very obvious cases.
431 // TODO: audit the pure tensor dialect rewriting rules
432 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
433 public:
434  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
435 
436  LogicalResult matchAndRewrite(tensor::CastOp op,
437  PatternRewriter &rewriter) const override {
438  Type srcType = op.getSource().getType();
439  Type dstType = op.getDest().getType();
440  // A nop cast simply folds away.
441  if (srcType == dstType) {
442  rewriter.replaceOp(op, op->getResults());
443  return success();
444  }
445  // See if a sparsity changing cast can be fused into producer.
446  if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
447  if (Operation *def = op.getSource().getDefiningOp()) {
448  if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
449  rewriter.modifyOpInPlace(def, [&]() {
450  def->getResult(0).setType(op->getResultTypes()[0]);
451  });
452  rewriter.replaceOp(op, def->getResult(0));
453  return success();
454  }
455  }
456  }
457  // Repair tensor casts with at least one sparse operand into the
458  // properly supported sparse_tensor.convert.
459  if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
460  rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
461  return success();
462  }
463  // Fail otherwise.
464  return failure();
465  }
466 };
467 
468 /// Rewrites a sequence of operations for sparse tensor selections into
469 /// semi-ring operations such that they can be compiled correctly by the
470 /// sparsifier. E.g., transforming the following sequence
471 ///
472 /// %sel = arith.select %cond, %sp1, %sp2
473 ///
474 /// to
475 ///
476 /// %sel = binary %sp1, %sp2:
477 /// both (%l, %r) {yield select %cond, %l, %r}
478 /// left (%l) {yield select %cond, %l, 0}
479 /// right (%r) {yield select %cond, 0, %r}
480 ///
481 /// TODO: We require the tensor used for extracting conditions to be dense
482 /// to sparsify the code. To support a sparse condition tensor, we need a
483 /// tri-nary operation.
484 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
485 public:
486  using OpRewritePattern<GenericOp>::OpRewritePattern;
487  LogicalResult matchAndRewrite(GenericOp op,
488  PatternRewriter &rewriter) const override {
489  // Rejects non sparse kernels.
490  if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
491  return failure();
492 
493  Location loc = op.getLoc();
494  SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
495  for (Operation &inst : *op.getBody()) {
496  // Matches pattern.
497  auto matched = isRewritablePattern(op, &inst);
498  if (!matched.has_value())
499  continue;
500 
501  rewriter.setInsertionPoint(&inst);
502  auto [c, t, f] = matched.value();
503  assert(t.getType() == f.getType());
504  auto selTp = t.getType();
505  auto c0 = constantZero(rewriter, loc, selTp);
506  auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
507  // Initializes all the blocks.
508  rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
509  {t.getLoc(), f.getLoc()});
510  rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
511  rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
512 
513  for (auto *r : binOp.getRegions()) {
514  Block *b = &r->front();
515  rewriter.setInsertionPointToStart(b);
516 
517  IRMapping irMap;
518  // Clones the cmp operations into the region to make the binary op
519  // admissible.
520  Value newC = c;
521  if (auto *def = c.getDefiningOp())
522  newC = rewriter.clone(*def, irMap)->getResult(0);
523 
524  irMap.map(c, newC);
525  if (r == &binOp.getLeftRegion()) {
526  irMap.map(t, b->getArgument(0));
527  irMap.map(f, c0);
528  } else if (r == &binOp.getRightRegion()) {
529  irMap.map(t, c0);
530  irMap.map(f, b->getArgument(0));
531  } else {
532  irMap.map(t, b->getArgument(0));
533  irMap.map(f, b->getArgument(1));
534  }
535  auto y = rewriter.clone(inst, irMap)->getResult(0);
536  rewriter.create<sparse_tensor::YieldOp>(loc, y);
537  }
538 
539  // We successfully rewrote the operation. We cannot do the replacement
540  // here because it would invalidate the iterator used by the loop that
541  // traverses the instructions.
542  semiRings.emplace_back(&inst, binOp);
543  }
544 
545  // Finalizes the replacement.
546  for (auto [sel, semi] : semiRings)
547  rewriter.replaceOp(sel, semi->getResults());
548 
549  return success(!semiRings.empty());
550  }
551 
552 private:
553  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
554  isRewritablePattern(GenericOp op, Operation *v) {
555  auto sel = dyn_cast<arith::SelectOp>(v);
556  if (!sel)
557  return std::nullopt;
558 
559  auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
560  auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
561  // TODO: For simplicity, we only handle cases where both true/false value
562  // are directly loaded from the input tensor. We can probably admit more
563  // cases in theory.
564  if (!tVal || !fVal)
565  return std::nullopt;
566 
567  // Helper lambda to determine whether the value is loaded from a dense input
568  // or is a loop invariant.
569  auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
570  if (auto bArg = dyn_cast<BlockArgument>(v);
571  bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
572  return true;
573  // If the value is defined outside the loop, it is a loop invariant.
574  return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
575  };
576 
577  // If the condition value is loaded directly from a dense tensor or is
578  // a loop invariant, we can sparsify the kernel.
579  auto cond = sel.getCondition();
580  if (isValFromDenseInputOrInvariant(cond))
581  return std::make_tuple(cond, tVal, fVal);
582 
583  Value cmpL, cmpR;
584  if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
585  matchers::m_Any(&cmpR))) ||
586  matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
587  matchers::m_Any(&cmpR)))) {
588  // TODO: we can do it recursively to check whether all the leaf values are
589  // loaded from dense tensors or are loop invariants.
590  if (isValFromDenseInputOrInvariant(cmpL) ||
591  isValFromDenseInputOrInvariant(cmpR))
592  return std::make_tuple(cond, tVal, fVal);
593  }
594 
595  return std::nullopt;
596  };
597 };
598 
599 /// Rewrites a sparse reduction that would not sparsify directly since
600 /// doing so would only iterate over the stored elements, ignoring the
601 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
602 /// (note that reductions like add/sub/or/xor can directly be sparsified
603 /// since the implicit zeros do not contribute to the final result).
604 /// Note that prod/and are still included since, even though they often
605 /// are nullified in sparse data, they may still occur for special
606 /// situations in which e.g. some rows in a sparse matrix are fully
607 /// dense. For min/max, including the implicit zeros is a much more
608 /// common situation.
609 ///
610 /// TODO: this essentially "densifies" the operation; we want to implement
611 /// this much more efficiently by performing the reduction over the
612 /// stored values, and feed in the zero once if there were *any*
613 /// implicit zeros as well; but for now, at least we provide
614 /// the functionality
615 ///
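/// An illustrative sketch of the rewritten reduction body for a "min"
/// reduction (operand names are hypothetical): the stored value %a is first
/// passed through a unary semi-ring op that yields the value when present
/// and an explicit zero when absent, and the result then feeds a custom
/// reduction against the accumulator:
///
///   %u = sparse_tensor.unary %a : f64 to f64
///          present = { ^bb0(%x: f64): sparse_tensor.yield %x : f64 }
///          absent  = { sparse_tensor.yield %cst0 : f64 }
///   %r = sparse_tensor.reduce %u, %acc, %identity : f64 {
///          ^bb0(%x: f64, %y: f64):
///            %m = arith.minimumf %x, %y : f64
///            sparse_tensor.yield %m : f64
///        }
///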
616 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
617 public:
618  using OpRewritePattern<GenericOp>::OpRewritePattern;
619 
620  LogicalResult matchAndRewrite(GenericOp op,
621  PatternRewriter &rewriter) const override {
622  // Reject non-reductions.
623  if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
624  op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
625  return failure();
626  auto *inp = op.getDpsInputOperand(0);
627  auto *init = op.getDpsInitOperand(0);
628  if (!isSparseTensor(inp))
629  return failure();
630  // Look for direct x = x OP y for semi-ring ready reductions.
631  auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
632  .getOperand(0)
633  .getDefiningOp();
634  if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
635  arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
636  arith::MaxUIOp>(red))
637  return failure();
638  Value s0 = op.getBlock()->getArgument(0);
639  Value s1 = op.getBlock()->getArgument(1);
640  if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
641  (red->getOperand(0) != s1 || red->getOperand(1) != s0))
642  return failure();
643  // Identity.
644  Location loc = op.getLoc();
645  Value identity =
646  rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
647  // Unary {
648  // present -> value
649  // absent -> zero.
650  // }
651  Type rtp = s0.getType();
652  rewriter.setInsertionPointToStart(&op.getRegion().front());
653  auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
654  Block *present =
655  rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
656  rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
657  rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
658  rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
659  rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
660  auto zero =
661  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
662  rewriter.create<sparse_tensor::YieldOp>(loc, zero);
663  rewriter.setInsertionPointAfter(semiring);
664  // CustomReduce {
665  // x = x REDUC y, identity
666  // }
667  auto custom = rewriter.create<sparse_tensor::ReduceOp>(
668  loc, rtp, semiring.getResult(), s1, identity);
669  Block *region =
670  rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
671  rewriter.setInsertionPointToStart(&custom.getRegion().front());
672  IRMapping irMap;
673  irMap.map(red->getOperand(0), region->getArgument(0));
674  irMap.map(red->getOperand(1), region->getArgument(1));
675  auto *cloned = rewriter.clone(*red, irMap);
676  rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
677  rewriter.setInsertionPointAfter(custom);
678  rewriter.replaceOp(red, custom.getResult());
679  return success();
680  }
681 };
682 
683 /// Sparse rewriting rule for the print operator. This operation is mainly used
684 /// for debugging and testing. As such, it lowers to the vector.print operation
685 /// which requires only very lightweight runtime support.
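///
/// For a small CSR matrix with four stored entries, the emitted code prints
/// output along these (illustrative) lines:
///
///   ---- Sparse Tensor ----
///   nse = 4
///   dim = ( 4, 8 )
///   lvl = ( 4, 8 )
///   pos[1] : ( 0, 1, 2, 3, 4,  )
///   crd[1] : ( 2, 7, 2, 3,  )
///   values : ( 1, 2, 3, 4,  )
///   ----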
686 struct PrintRewriter : public OpRewritePattern<PrintOp> {
687 public:
688  using OpRewritePattern<PrintOp>::OpRewritePattern;
689  LogicalResult matchAndRewrite(PrintOp op,
690  PatternRewriter &rewriter) const override {
691  Location loc = op.getLoc();
692  auto tensor = op.getTensor();
693  auto stt = getSparseTensorType(tensor);
694  // Header with NSE.
695  auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
696  rewriter.create<vector::PrintOp>(
697  loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
698  rewriter.create<vector::PrintOp>(loc, nse);
699  // Print run-time contents for dim/lvl sizes.
700  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
701  printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
702  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
703  printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
704  // Use the "codegen" foreach loop construct to iterate over
705  // all typical sparse tensor components for printing.
706  foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
707  &stt](Type, FieldIndex,
708  SparseTensorFieldKind kind,
709  Level l, LevelType) {
710  switch (kind) {
711  case SparseTensorFieldKind::StorageSpec: {
712  break;
713  }
714  case SparseTensorFieldKind::PosMemRef: {
715  auto lvl = constantIndex(rewriter, loc, l);
716  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
717  rewriter.create<vector::PrintOp>(
718  loc, lvl, vector::PrintPunctuation::NoPunctuation);
719  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
720  auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
721  printContents(rewriter, loc, pos);
722  break;
723  }
724  case SparseTensorFieldKind::CrdMemRef: {
725  auto lvl = constantIndex(rewriter, loc, l);
726  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
727  rewriter.create<vector::PrintOp>(
728  loc, lvl, vector::PrintPunctuation::NoPunctuation);
729  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
730  Value crd = nullptr;
731  // For COO AoS storage, we want to print a single, linear view of
732  // the full coordinate storage at this level. For any other storage,
733  // we show the coordinate storage for every individual level.
734  if (stt.getAoSCOOStart() == l)
735  crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
736  else
737  crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
738  printContents(rewriter, loc, crd);
739  break;
740  }
741  case SparseTensorFieldKind::ValMemRef: {
742  rewriter.create<vector::PrintOp>(loc,
743  rewriter.getStringAttr("values : "));
744  auto val = rewriter.create<ToValuesOp>(loc, tensor);
745  printContents(rewriter, loc, val);
746  break;
747  }
748  }
749  return true;
750  });
751  rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
752  rewriter.eraseOp(op);
753  return success();
754  }
755 
756 private:
757  // Helper to print contents of a single memref. Note that for the "push_back"
758  // vectors, this prints the full capacity, not just the size. This is done
759  // on purpose, so that clients see how much storage has been allocated in
760  // total. Contents of the extra capacity in the buffer may be uninitialized
761  // (unless the flag enable-buffer-initialization is set to true).
762  //
763  // Generates code to print:
764  // ( a0, a1, ... )
765  static void printContents(PatternRewriter &rewriter, Location loc,
766  Value vec) {
767  // Open bracket.
768  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
769  // For loop over elements.
770  auto zero = constantIndex(rewriter, loc, 0);
771  auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
772  auto step = constantIndex(rewriter, loc, 1);
773  auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
774  rewriter.setInsertionPointToStart(forOp.getBody());
775  auto idx = forOp.getInductionVar();
776  auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
777  if (llvm::isa<ComplexType>(val.getType())) {
778  // Since the vector dialect does not support complex types in any op,
779  // we split those into (real, imag) pairs here.
780  Value real = rewriter.create<complex::ReOp>(loc, val);
781  Value imag = rewriter.create<complex::ImOp>(loc, val);
782  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
783  rewriter.create<vector::PrintOp>(loc, real,
784  vector::PrintPunctuation::Comma);
785  rewriter.create<vector::PrintOp>(loc, imag,
786  vector::PrintPunctuation::Close);
787  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
788  } else {
789  rewriter.create<vector::PrintOp>(loc, val,
790  vector::PrintPunctuation::Comma);
791  }
792  rewriter.setInsertionPointAfter(forOp);
793  // Close bracket and end of line.
794  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
795  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
796  }
797 
798  // Helper method to print run-time lvl/dim sizes.
799  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
800  unsigned size, bool isDim) {
801  // Open bracket.
802  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
803  // Print unrolled contents (dimop requires constant value).
804  for (unsigned i = 0; i < size; i++) {
805  auto idx = constantIndex(rewriter, loc, i);
806  Value val;
807  if (isDim)
808  val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
809  else
810  val = rewriter.create<LvlOp>(loc, tensor, idx);
811  rewriter.create<vector::PrintOp>(
812  loc, val,
813  i != size - 1 ? vector::PrintPunctuation::Comma
814  : vector::PrintPunctuation::NoPunctuation);
815  }
816  // Close bracket and end of line.
817  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
818  rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
819  }
820 };
821 
822 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
823 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
824 public:
825  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
826 
827  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
828  PatternRewriter &rewriter) const override {
829  Location loc = op.getLoc();
830  Value srcTensor = op.getSource();
831  const auto srcTp = getSparseTensorType(srcTensor);
832  const auto dstTp = getSparseTensorType(op.getResult());
833 
834  if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
835  !dstTp.hasStaticDimShape())
836  return failure();
837 
838  SmallVector<Value> srcSizes;
839  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
840  SmallVector<Value> dstSizes;
841  for (Dimension d : dstTp.getDimShape())
842  dstSizes.push_back(constantIndex(rewriter, loc, d));
843 
844  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
845  // Only need an unordered COO buffer if input and output are not sorted
846  // in the same way.
847  Type bufferTp = getBufferType(
848  dstTp.withoutDimToLvl(),
849  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
850  SmallVector<Value> dynSizes;
851  Value buffer = rewriter
852  .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
853  nnz, Attribute())
854  .getResult();
855 
856  // Convert src coordinates to dst coordinates by first collapsing them to
857  // 1D and then expanding them to match the rank of the destination tensor.
858  // Implemented as follows:
859  // foreach srcCoords %srcTensor
860  // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
861  // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
862  // insert expandedCoords, %buffer
863  //
864  // followed by an optional
865  // %t = sparse_tensor.cast %tmp
866  // depending on whether the input/output are sorted in the same way.
867  const auto encSrc = srcTp.getEncoding();
868  ForeachOp foreachOp = rewriter.create<ForeachOp>(
869  loc, srcTensor, buffer,
870  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
871  ValueRange reduc) {
872  const Dimension srcRank = srcTp.getDimRank();
873  SmallVector<Value> srcDcvs;
874  srcDcvs.reserve(srcRank);
875  for (Dimension d = 0; d < srcRank; d++) {
876  Level lvl = toLvl(encSrc, d);
877  srcDcvs.push_back(srcLcvs[lvl]);
878  }
879 
880  Value collapseSize = constantIndex(builder, loc, 1);
881  for (Dimension d = 0; d < srcRank; d++)
882  collapseSize =
883  builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
884  SmallVector<Value, 1> collapsedSizes = {collapseSize};
885 
886  ReassociationIndices collapseIdx;
887  for (Dimension i = 0; i < srcRank; i++)
888  collapseIdx.push_back(i);
889  SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
890  SmallVector<Value, 1> collapsedDcvs;
891  reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
892  collapsedSizes, collapsedDcvs);
893 
894  ReassociationIndices expandIdx;
895  for (Dimension i = 0; i < dstTp.getDimRank(); i++)
896  expandIdx.push_back(i);
897  SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
898  SmallVector<Value> dstDcvs;
899  reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
900  dstSizes, dstDcvs);
901 
902  auto t =
903  builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
904  builder.create<sparse_tensor::YieldOp>(loc, t);
905  });
906 
907  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
908  if (bufferTp != dstTp) {
909  auto dstRTT = dstTp.getRankedTensorType();
910  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
911  rewriter.create<DeallocTensorOp>(loc, t);
912  t = converted;
913  }
914  rewriter.replaceOp(op, t);
915  return success();
916  }
917 };
918 
919 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
920 template <typename ReshapeOp>
921 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
922 public:
923  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
924 
925  LogicalResult matchAndRewrite(ReshapeOp op,
926  PatternRewriter &rewriter) const override {
927  Location loc = op.getLoc();
928  Value srcTensor = op.getSrc();
929  const auto srcTp = getSparseTensorType(srcTensor);
930  const auto dstTp = getSparseTensorType(op.getResult());
931  if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
932  return failure();
933 
934  // Generate code to represent the static dimension constants or compute
935  // the dynamic dimension values.
936  SmallVector<Value> srcSizes;
937  sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
938  SmallVector<Value> dstSizes;
939  SmallVector<Value> dstDynSizes;
940  if (dstTp.hasStaticDimShape()) {
941  for (Dimension d : dstTp.getDimShape())
942  dstSizes.push_back(constantIndex(rewriter, loc, d));
943  } else {
944  ArrayRef<Size> dstShape = dstTp.getDimShape();
945  genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
946  op.getReassociationIndices());
947  for (auto [idx, shape] : llvm::enumerate(dstShape)) {
948  if (shape == ShapedType::kDynamic)
949  dstDynSizes.push_back(dstSizes[idx]);
950  }
951  }
952  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
953  // Only need an unordered COO buffer if input and output are not sorted
954  // in the same way.
955  Type bufferTp = getBufferType(
956  dstTp.withoutDimToLvl(),
957  !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
958 
959  Value buffer =
960  rewriter
961  .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
962  /*sizeHint=*/nnz, Attribute())
963  .getResult();
964 
965  // Implement the sparse2sparse reshape as follows:
966  // foreach srcCoords %srcTensor
967  // insert reshapeCvs(srcCoords), %buffer
968  //
969  // followed by an optional
970  // %t = sparse_tensor.cast %tmp
971  // depending on whether the input/output are sorted in the same way.
972  const auto encSrc = srcTp.getEncoding();
973  ForeachOp foreachOp = rewriter.create<ForeachOp>(
974  loc, srcTensor, buffer,
975  [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
976  ValueRange reduc) {
977  const Dimension dimRank = srcTp.getDimRank();
978  SmallVector<Value> srcDcvs;
979  srcDcvs.reserve(dimRank);
980  for (Dimension d = 0; d < dimRank; d++) {
981  Level lvl = toLvl(encSrc, d);
982  srcDcvs.push_back(srcLcvs[lvl]);
983  }
984  SmallVector<Value> dstDcvs;
985  reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
986  srcDcvs, dstSizes, dstDcvs);
987  auto t =
988  builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
989  builder.create<sparse_tensor::YieldOp>(loc, t);
990  });
991 
992  Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
993  if (bufferTp != dstTp) {
994  auto dstRTT = dstTp.getRankedTensorType();
995  Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
996  rewriter.create<DeallocTensorOp>(loc, t);
997  t = converted;
998  }
999  rewriter.replaceOp(op, t);
1000  return success();
1001  }
1002 };
1003 
1004 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1005 /// operator.
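///
/// An illustrative dense-to-sparse sketch (the #CSR encoding is hypothetical):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///        : tensor<12xf64> into tensor<3x4xf64, #CSR>
///
/// is rewritten into a purely dense reshape followed by a conversion:
///
///   %d = tensor.expand_shape %0 [[0, 1]] : tensor<12xf64> into tensor<3x4xf64>
///   %1 = sparse_tensor.convert %d : tensor<3x4xf64> to tensor<3x4xf64, #CSR>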
1006 template <typename ReshapeOp>
1007 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1008 public:
1009  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1010 
1011  LogicalResult matchAndRewrite(ReshapeOp op,
1012  PatternRewriter &rewriter) const override {
1013  Location loc = op->getLoc();
1014  auto encDst = getSparseTensorEncoding(op.getResult().getType());
1015  auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1016  // Since a pure dense expansion is very cheap (change of view), for
1017  // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1018  // conversion from the reshape operation itself.
1019  // All other cases are handled elsewhere.
1020  if (encDst && encSrc) {
1021  return failure();
1022  }
1023  if (encSrc) {
1024  auto rtp = getRankedTensorType(op.getSrc());
1025  auto denseTp =
1026  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1027  auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
1028  rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1029  return success();
1030  }
1031  if (encDst) {
1032  auto rtp = getRankedTensorType(op.getResult());
1033  auto denseTp =
1034  RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1035  auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
1036  op.getReassociation());
1037  Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
1038  rewriter.replaceOp(op, convert);
1039  return success();
1040  }
1041  return failure();
1042  }
1043 };
1044 
1045 // A trivial wrapper to help generate different operations for dense/sparse
1046 // tensors.
1047 struct TensorLike {
1048  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1049  ValueRange sizes) {
1050  SmallVector<Value> dynSzs;
1051  getDynamicSizes(rtt, sizes, dynSzs);
1052 
1053  val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
1054  if (!isSparse()) {
1055  Value c0 = constantZero(builder, loc, rtt.getElementType());
1056  val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
1057  }
1058  }
1059 
1060  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1061  val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1062  }
1063 
1064  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1065  if (isSparse())
1066  return builder.create<LoadOp>(loc, val, true);
1067  return val;
1068  }
1069 
1070  bool isSparse() const {
1071  return getSparseTensorEncoding(val.getType()) != nullptr;
1072  }
1073 
1074  Value val;
1075 };
1076 
1077 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1078  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1079  LogicalResult matchAndRewrite(tensor::DimOp op,
1080  PatternRewriter &rewriter) const override {
1081  std::optional<int64_t> dim = op.getConstantIndex();
1082  auto stt = getSparseTensorType(op.getSource());
1083  if (!dim || !stt.hasEncoding())
1084  return failure();
1085 
1086  if (stt.isPermutation()) {
1087  rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1088  toLvl(stt.getEncoding(), *dim));
1089  return success();
1090  }
1091 
1092  // Non-permutation dim2lvl/lvl2dim maps.
1093  // Compute as follows:
1094  // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1095  // Note that this is not the most efficient way (but a more general one)
1096  // for the lvl-to-dim translation, e.g., for BSR, the dimension size can
1097  // be computed simply as lvl_size * block_size.
1098  Location loc = op.getLoc();
1099  SmallVector<Value> maxLvlCrds;
1100  for (Level l = 0; l < stt.getLvlRank(); l++) {
1101  Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1102  Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1103  loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
1104  maxLvlCrds.push_back(maxLvlCrd);
1105  }
1106 
1107  AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
1108  Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1109  op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
1110  maxLvlCrds);
1111 
1112  Value dimSz = rewriter.create<arith::AddIOp>(
1113  loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1114  rewriter.replaceOp(op, dimSz);
1115  return success();
1116  }
1117 };
1118 
1119 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1120  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
1121  LogicalResult matchAndRewrite(ConcatenateOp op,
1122  PatternRewriter &rewriter) const override {
1123  if (op.needsExtraSort())
1124  op.emitError("ConcatenateOp not staged");
1125 
1126  const Location loc = op.getLoc();
1127  const auto dstTp = getSparseTensorType(op);
1128  const Dimension conDim = op.getDimension();
1129  SmallVector<Value> sizes;
1130  concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1131 
1132  // %t = concatenate %s1, %s2, %s3 {dim = 1}
1133  // ==>
1134  // if (isSparseDst)
1135  // if (allDense)
1136  // %tmp = bufferization.alloc_tensor dstTp
1137  // else
1138  // %tmp = bufferization.alloc_tensor : unordered COO
1139  // else
1140  // %tmp = memref.alloc : dense tensor
1141  // foreach in %s1 : insert d0, d1, %tmp
1142  // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1143  // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1144 
1145  TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1146  Value offset = constantIndex(rewriter, loc, 0);
1147  Value iterArg = dstBuf.val;
1148 
1149  ForeachOp foreachOp;
1150  for (Value input : op.getInputs()) {
1151  // Builds a for op for each input tensor to append new values into the
1152  // output tensor.
1153  foreachOp = rewriter.create<ForeachOp>(
1154  loc, input, iterArg,
1155  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1156  ValueRange reduc) {
1157  SmallVector<Value> offDimCrd(dcvs);
1158  offDimCrd[conDim] =
1159  builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1160 
1161  // Enters foreach, updates the SSA chain.
1162  dstBuf.val = reduc.front();
1163  if (!dstTp.isAllDense()) {
1164  Value cond = genIsNonzero(builder, loc, v);
1165  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1166  /*else*/ true);
1167  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1168  builder.create<scf::YieldOp>(loc, dstBuf.val);
1169 
1170  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1171  dstBuf.insert(builder, loc, v, offDimCrd);
1172  builder.create<scf::YieldOp>(loc, dstBuf.val);
1173 
1174  // Exits the ifOp, update the sparse tensor SSA value.
1175  builder.setInsertionPointAfter(ifOp);
1176  dstBuf.val = ifOp.getResult(0);
1177  } else {
1178  dstBuf.insert(builder, loc, v, offDimCrd);
1179  }
1180  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1181  });
1182  // Accumulates the offset. Note that only static-shaped inputs are allowed
1183  // by concatenate op verifier, which saves us from computing the offset
1184  // dynamically.
1185  const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1186  assert(!ShapedType::isDynamic(sz));
1187  offset = rewriter.create<arith::AddIOp>(loc, offset,
1188  constantIndex(rewriter, loc, sz));
1189  iterArg = foreachOp.getResult(0);
1190  dstBuf.val = iterArg;
1191  }
1192 
1193  dstBuf.val = iterArg;
1194  Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1195  rewriter.replaceOp(op, ret);
1196  return success();
1197  }
1198 };
1199 
1200 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1201  using OpRewritePattern<ConvertOp>::OpRewritePattern;
1202  LogicalResult matchAndRewrite(ConvertOp op,
1203  PatternRewriter &rewriter) const override {
1204  if (op.needsExtraSort())
1205  return op.emitError("ConvertOp not staged.");
1206 
1207  // TODO: Maybe we want a different operation for this too.
1208  auto encDst = getSparseTensorEncoding(op.getType());
1209  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1210  if (encDst && encSrc && !encSrc.isSlice() &&
1211  encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1212  // Trivial tensor conversion and simple element type conversion are
1213  // handled in codegen.
1214  return failure();
1215  }
1216 
1217  Location loc = op.getLoc();
1218  Value src = op.getSource();
1219 
1220  SparseTensorType srcStt = getSparseTensorType(op.getSource());
1221  SparseTensorType dstStt = getSparseTensorType(op.getDest());
1222 
1223  bool fromSparseConst = false;
1224  if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1225  if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1226  fromSparseConst = true;
1227 
1228  const AffineMapAttr foreachOrder =
1229  (!dstStt.isIdentity() && fromSparseConst)
1230  ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1231  : nullptr;
1232 
1233  bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1234 
1235  SmallVector<Value> sizes;
1236  sizesFromSrc(rewriter, sizes, loc, src);
1237  ValueRange vs;
1238  TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1239 
1240  auto foreachOp = rewriter.create<ForeachOp>(
1241  loc, src, dstBuf.val, foreachOrder,
1242  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1243  ValueRange reduc) {
1244  // Enters the loop, update the SSA value for insertion chain.
1245  dstBuf.val = reduc.front();
1246  if (!skipZeroCheck) {
1247  Value cond = genIsNonzero(builder, loc, v);
1248  auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1249  /*else*/ true);
1250  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1251  builder.create<scf::YieldOp>(loc, dstBuf.val);
1252 
1253  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1254  dstBuf.insert(builder, loc, v, dcvs);
1255  builder.create<scf::YieldOp>(loc, dstBuf.val);
1256 
1257  // Exits the ifOp, update the sparse tensor SSA value.
1258  builder.setInsertionPointAfter(ifOp);
1259  dstBuf.val = ifOp.getResult(0);
1260  } else {
1261  dstBuf.insert(builder, loc, v, dcvs);
1262  }
1263  builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1264  });
1265 
1266  rewriter.setInsertionPointAfter(foreachOp);
1267 
1268  // Exits the for loop, links the SSA chain.
1269  dstBuf.val = foreachOp.getResult(0);
1270 
1271  Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1272  rewriter.replaceOp(op, ret);
1273  return success();
1274  }
1275 };
1276 
1277 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1278  using OpRewritePattern<CrdTranslateOp>::OpRewritePattern;
1279  LogicalResult matchAndRewrite(CrdTranslateOp op,
1280  PatternRewriter &rewriter) const override {
1281  AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1282  ? op.getEncoder().getDimToLvl()
1283  : op.getEncoder().getLvlToDim();
1284 
1285  SmallVector<Value> outCrds;
1286  for (AffineExpr result : map.getResults()) {
1287  // TODO: we should probably expand the affine map to IR using our own
1288  // rules, since affine.apply assumes signed values, while the coordinates
1289  // we provide must always be signless.
1290  Value trans = rewriter.create<affine::AffineApplyOp>(
1291  op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1292  op.getInCrds());
1293  outCrds.push_back(trans);
1294  }
1295  rewriter.replaceOp(op, outCrds);
1296  return success();
1297  }
1298 };
1299 
1300 /// Sparse rewriting rule for the foreach operator.
1301 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1302 public:
1303  using OpRewritePattern<ForeachOp>::OpRewritePattern;
1304 
1305  LogicalResult matchAndRewrite(ForeachOp op,
1306  PatternRewriter &rewriter) const override {
1307 
1308  auto loc = op.getLoc();
1309  Value input = op.getTensor();
1310  SmallVector<Value> reduc = op.getInitArgs();
1311  const auto stt = getSparseTensorType(input);
1312  const Level lvlRank = stt.getLvlRank();
1313 
1314  // Special-case: foreach over a sparse constant uses its own rewriting
1315  // rule.
1316  if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1317  if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1318  return genForeachOnSparseConstant(op, rewriter, attr);
1319  }
1320  }
1321 
1322  // Otherwise, use loop emitter to generate loops.
1323  const auto enc = stt.getEncoding();
1324 
1325  // 1. Generates loop for the sparse input.
1326  LoopEmitter loopEmitter(
1327  ValueRange{input},
1328  StringAttr::get(getContext(), ForeachOp::getOperationName()));
1329  loopEmitter.initializeLoopEmit(rewriter, loc);
1330  for (Level l = 0; l < lvlRank; l++) {
1331  // TODO: provide a utility function for loop sequences that contain
1332  // only one for loop?
1333  const SmallVector<TensorLevel, 1> tidLvls{
1334  loopEmitter.makeTensorLevel(0, l)};
1335  loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1336  // Note that reduc will be taken care of by the loop emitter and gets
1337  // updated in place.
1338  loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1339  reduc);
1340  }
1341 
1342  SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1343  if (op.getOrder()) {
1344  // TODO: Support it so that we can do direct conversion from CSR->BSR.
1345  llvm_unreachable(
1346  "Level order not yet implemented on non-constant input tensors.");
1347  }
1348 
1349  Value vals = loopEmitter.getValBuffer()[0];
1350  SmallVector<Value> pos = loopEmitter.getValPosits(0);
1351  // Loads the value from sparse tensor using position-index;
1352  // loads the value from dense tensor using coords.
1353  Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1354  : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1355 
1356  // 2. Inline the block in the foreach operator.
1357  Block *srcBlock = op.getBody();
1358 
1359  // Remap coordinates.
1360  SmallVector<Value> args =
1361  enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1362 
1363  // Remap value.
1364  args.push_back(val);
1365  // Remap reduction variables.
1366  args.append(reduc);
1367 
1368  // Remove sparse_tensor.yield.
1369  SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1370  rewriter.eraseOp(srcBlock->getTerminator());
1371 
1372  Operation &last = rewriter.getBlock()->back();
1373  if (llvm::isa<scf::YieldOp>(last)) {
1374  // Because `scf.for` inserts an implicit yield op when there is no
1375  // reduction variable upon creation, we reset the insertion point such
1376  // that the block is inlined right before the yield op.
1377  rewriter.setInsertionPoint(&last);
1378  }
1379 
1380  rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1381  rewriter.getInsertionPoint(), args);
1382  rewriter.setInsertionPointToEnd(rewriter.getBlock());
1383  for (Level l = 0; l < lvlRank; l++) {
1384  // Link the reduction chain. Note that the loop emitter updates reducValue
1385  // in place.
1386  loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1387  loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1388  }
1389 
1390  // Replace the foreach operator with the value returned by the outermost
1391  // for loop.
1392  rewriter.replaceOp(op, reducValue);
1393  return success();
1394  }
1395 };
1396 
1397 /// Sparse rewriting rule for the new operator.
1398 struct NewRewriter : public OpRewritePattern<NewOp> {
1399  using OpRewritePattern<NewOp>::OpRewritePattern;
1400  LogicalResult matchAndRewrite(NewOp op,
1401  PatternRewriter &rewriter) const override {
1402  Location loc = op.getLoc();
1403  auto stt = getSparseTensorType(op.getResult());
1404  if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1405  return failure();
1406 
1407  // Implement the NewOp as follows:
1408  // %orderedCoo = sparse_tensor.new %filename
1409  // %t = sparse_tensor.convert %orderedCoo
1410  // with enveloping reinterpreted_map ops for non-permutations.
1411  RankedTensorType dstTp = stt.getRankedTensorType();
1412  RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1413  Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1414  Value convert = cooTensor;
1415  auto enc = stt.getEncoding();
1416  if (!stt.isPermutation()) { // demap coo, demap dstTp
1417  auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1418  convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1419  dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1420  }
1421  convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1422  if (!stt.isPermutation()) // remap to original enc
1423  convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1424  rewriter.replaceOp(op, convert);
1425 
1426  // Release the temporary ordered COO tensor.
1427  rewriter.setInsertionPointAfterValue(convert);
1428  rewriter.create<DeallocTensorOp>(loc, cooTensor);
1429 
1430  return success();
1431  }
1432 };
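// For illustration only (not part of the original source): assuming a 2-D CSR
// destination, the NewRewriter above roughly produces IR of the form
//   %coo = sparse_tensor.new %src : !llvm.ptr to tensor<?x?xf64, #SortedCOO>
//   %csr = sparse_tensor.convert %coo
//            : tensor<?x?xf64, #SortedCOO> to tensor<?x?xf64, #CSR>
//   bufferization.dealloc_tensor %coo : tensor<?x?xf64, #SortedCOO>
// where #SortedCOO and #CSR are placeholder encodings and the source type is
// schematic.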
1433 
1434 /// Sparse rewriting rule for the out operator.
1435 struct OutRewriter : public OpRewritePattern<OutOp> {
1436  using OpRewritePattern::OpRewritePattern;
1437  LogicalResult matchAndRewrite(OutOp op,
1438  PatternRewriter &rewriter) const override {
1439  Location loc = op.getLoc();
1440  // Calculate NNZ.
1441  Value src = op.getTensor();
1442  Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1443 
1444  // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1445  const auto srcTp = getSparseTensorType(src);
1446  const Dimension dimRank = srcTp.getDimRank();
1447  Type indexTp = rewriter.getIndexType();
1448  Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1449 
1450  // Generate code to calculate dimension size values and store the values to
1451  // the buffer.
1452  SmallVector<Value> dims;
1453  sizesForTensor(rewriter, dims, loc, srcTp, src);
1454  for (Dimension d = 0; d < dimRank; d++) {
1455  rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1456  constantIndex(rewriter, loc, d));
1457  }
1458 
1459  // Create a sparse tensor writer and output meta data.
1460  Type opaqueTp = getOpaquePointerType(rewriter);
1461  Value writer =
1462  createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1463  {op.getDest()}, EmitCInterface::Off)
1464  .getResult(0);
1465  Value rankValue = constantIndex(rewriter, loc, dimRank);
1466  createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1467  {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1468 
1469  Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1470  Type eltTp = srcTp.getElementType();
1471  SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1472  primaryTypeFunctionSuffix(eltTp)};
1473  Value value = genAllocaScalar(rewriter, loc, eltTp);
1474  ModuleOp module = op->getParentOfType<ModuleOp>();
1475 
1476  // For each element in the source tensor, output the element.
1477  rewriter.create<ForeachOp>(
1478  loc, src, std::nullopt,
1479  [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1480  ValueRange reduc) {
1481  for (Dimension d = 0; d < dimRank; d++) {
1482  rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1483  constantIndex(builder, loc, d));
1484  }
1485  rewriter.create<memref::StoreOp>(loc, v, value);
1486  SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1487  FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1488  EmitCInterface::On);
1489  builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1490  builder.create<sparse_tensor::YieldOp>(loc);
1491  });
1492 
1493  // Release the writer.
1494  createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1495  EmitCInterface::Off);
1496 
1497  rewriter.eraseOp(op);
1498  return success();
1499  }
1500 };
1501 
1502 } // namespace
1503 
1504 //===---------------------------------------------------------------------===//
1505 // Methods that add patterns described in this file to a pattern list.
1506 //===---------------------------------------------------------------------===//
1507 
1508 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1509  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
1510  FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
1511  GenSemiRingSelect, PrintRewriter>(patterns.getContext());
1512 }
1513 
1514 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1515  bool enableRT,
1516  bool enableConvert) {
1517  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1518  ReshapeRewriter<tensor::CollapseShapeOp>,
1519  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1520  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1521  SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1522  patterns.getContext());
1523 
1524  if (enableConvert)
1525  patterns.add<DirectConvertRewriter>(patterns.getContext());
1526  if (!enableRT)
1527  patterns.add<NewRewriter>(patterns.getContext());
1528 }
1529 
1530 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1531  // Run CrdTranslateRewriter later in the pipeline so that the operation can
1532  // be folded before lowering to affine.apply.
1533  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1534 }
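As a usage sketch only: a downstream pass can hand these populate* entry points to the greedy pattern driver. The pass name below and the assumption that the entry points are declared in the SparseTensor Transforms Passes.h header are illustrative rather than taken from this file; only populatePreSparsificationRewriting is shown, but the two lowering variants are wired up the same way.

// Sketch under the assumptions stated above: a custom pass that applies the
// pre-sparsification rewrites registered by this file.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct PreSparsificationRewritePass
    : public PassWrapper<PreSparsificationRewritePass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePreSparsificationRewriting(patterns);
    // Drive the rewrites to a fixed point over the whole module.
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

The lowering entry points additionally take enableRT and enableConvert flags, so a pipeline can choose between the runtime-library path (NewRewriter is registered only when enableRT is false) and the direct conversion path (DirectConvertRewriter is registered only when enableConvert is true), exactly as the registration code above shows.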