1//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting rules that are specific to sparse tensors.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Utils/CodegenUtils.h"
14#include "Utils/LoopEmitter.h"
15
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Matchers.h"
31#include "mlir/Support/LLVM.h"
32
33using namespace mlir;
34using namespace mlir::bufferization;
35using namespace mlir::linalg;
36using namespace mlir::sparse_tensor;
37
38//===---------------------------------------------------------------------===//
39// Helper methods for the actual rewriting rules.
40//===---------------------------------------------------------------------===//
41
42// Helper method to match any typed zero.
43static bool isZeroValue(Value val) {
44 return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45}
46
47// Helper to detect a sparse tensor type operand.
48static bool isSparseTensor(Value v) {
49 auto enc = getSparseTensorEncoding(v.getType());
50 return enc && !llvm::all_of(enc.getLvlTypes(),
51 [](auto lt) { return lt == LevelFormat::Dense; });
52}
53static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54
55// Helper method to find zero/uninitialized tensor materialization.
56static bool isMaterializing(OpOperand *op, bool isZero) {
57 Value val = op->get();
58 // Check allocation, with zero alloc when required.
59 if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60 Value copy = alloc.getCopy();
61 if (isZero)
62 return copy && isZeroValue(copy);
63 return !copy;
64 }
65 // Check for empty tensor materialization.
66 if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67 return !isZero;
68 // Last resort for zero alloc: the whole value is zero.
69 return isZero && isZeroValue(val);
70}
71
72// Helper to detect sampling operation.
73static bool isSampling(GenericOp op) {
74 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77 // Both scalar input arguments used exactly once.
78 Value s1 = op.getBlock()->getArgument(0);
79 Value s2 = op.getBlock()->getArgument(1);
80 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82 }
83 }
84 return false;
85}
86
87// Helper to detect chain of multiplications that do not involve x.
88static bool isMulChain(Value val, Value x) {
89 if (auto arg = dyn_cast<BlockArgument>(val))
90 return arg != x;
91 if (auto *def = val.getDefiningOp()) {
92 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93 return isMulChain(def->getOperand(0), x) &&
94 isMulChain(def->getOperand(1), x);
95 }
96 return false;
97}
98
99// Helper to detect x = x + <multiplications>.
100static bool isSumOfMul(GenericOp op) {
101 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104 Value x = op.getBlock()->getArguments().back();
105 return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106 (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107 }
108 }
109 return false;
110}
111
112// Helper to detect direct yield of a zero value.
113static bool isZeroYield(GenericOp op) {
114 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115 if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116 if (arg.getOwner()->getParentOp() == op) {
117 return isZeroValue(op->getOperand(arg.getArgNumber()));
118 }
119 }
120 return isZeroValue(yieldOp.getOperand(0));
121}
122
123/// Populates the given sizes array from the type (for static sizes) and from
124/// the tensor (for dynamic sizes).
125static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126 Location loc, ShapedType stp, Value tensor) {
127 for (const auto &d : enumerate(stp.getShape())) {
128 Value dim;
129 if (d.value() == ShapedType::kDynamic)
130 dim = tensor::DimOp::create(builder, loc, tensor, d.index());
131 else
132 dim = constantIndex(builder, loc, d.value());
133 sizes.push_back(dim);
134 }
135}
136
137static RankedTensorType getBufferType(const SparseTensorType &stt,
138 bool needTmpCOO) {
139 return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140 : stt.getRankedTensorType();
141}
142
143/// Collects the dynamic dimension sizes for `tp` with the assumption that
144/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145/// sizes to dynSizes.
146static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147 SmallVectorImpl<Value> &dynSizes) {
148 for (const auto &d : enumerate(tp.getShape())) {
149 if (d.value() == ShapedType::kDynamic)
150 dynSizes.push_back(sizes[d.index()]);
151 }
152}
153
154static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155 RewriterBase &rewriter,
156 SparseElementsAttr attr) {
157 auto loc = op.getLoc();
158 SmallVector<Value> reduc = op.getInitArgs();
159
160 // Foreach on constant.
161 foreachInSparseConstant(
162 rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163 [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164 SmallVector<Value> args;
165 args.append(cvs.begin(), cvs.end());
166 args.push_back(v);
167 args.append(reduc);
168 // Clones the foreach op to get a copy of the loop body.
169 auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170 assert(args.size() == cloned.getBody()->getNumArguments());
171 Operation *yield = cloned.getBody()->getTerminator();
172 rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173 // clean up
174 rewriter.eraseOp(cloned);
175 reduc = yield->getOperands();
176 rewriter.eraseOp(yield);
177 });
178
179 rewriter.replaceOp(op, reduc);
180 return success();
181}
182
183/// Populates the given sizes array for concatenation from types (for static
184/// sizes) and from the source tensors (for dynamic sizes).
185static void concatSizesFromInputs(OpBuilder &builder,
186 SmallVectorImpl<Value> &sizes, Location loc,
187 ShapedType dstTp, ValueRange srcs,
188 unsigned dim) {
189 auto dstShape = dstTp.getShape();
190 sizesFromSrc(builder, sizes, loc, srcs[0]);
191
192 // Sum up on the `dim` if the dimension is dynamic.
193 if (dstShape[dim] != ShapedType::kDynamic) {
194 // Faithfully take the static size.
195 sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196 } else {
197 // Else, compute the shape dynamically.
198 for (const auto &src : srcs.drop_front()) {
199 Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200 // Sum up all the sizes.
201 sizes[dim] = arith::AddIOp::create(builder, loc, sizes[dim], srcSz);
202 }
203 }
204}
205
206//===---------------------------------------------------------------------===//
207// The actual sparse tensor rewriting rules.
208//===---------------------------------------------------------------------===//
209
210namespace {
211
212/// TODO: move it to tensor dialect instead.
213///
214/// Fold `tensor.concat` and `tensor.extract_slice`
215///
216/// %concat = tensor.concat dim(2) %t0, %t1
217/// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222///
223/// Becomes
224///
225/// %extract0, %extract1 = %t0, %t1
226struct FuseExtractSliceWithConcat
227 : public OpRewritePattern<tensor::ExtractSliceOp> {
228 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229
230 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231 PatternRewriter &rewriter) const override {
232 auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233 if (!concatOp)
234 return failure();
235
236 Location loc = extractOp.getLoc();
237 int64_t dim = concatOp.getDim();
238 int64_t rank = extractOp.getResultType().getRank();
239
240 SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241 SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242
243 // Compute the partial sums for the slice offsets.
244 AffineExpr sum = rewriter.getAffineDimExpr(0);
245 SmallVector<AffineExpr> partialSums = {sum};
246 SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247 for (auto [idx, input] :
248 llvm::enumerate(concatOp.getInputs().drop_back())) {
249 sum = sum + rewriter.getAffineDimExpr(idx + 1);
250 partialSums.push_back(sum);
251 offsetStrides.push_back(
252 rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253 }
254 auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255 partialSums, rewriter.getContext());
256 SmallVector<OpFoldResult> dimOffsets =
257 affine::makeComposedFoldedMultiResultAffineApply(
258 rewriter, loc, partialSumMap, offsetStrides);
259
260 auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261 for (auto [l, r] : llvm::zip(lhs, rhs)) {
262 std::optional<int64_t> staticVal = getConstantIntValue(l);
263 if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264 return false;
265 }
266 return lhs.size() == rhs.size();
267 };
268
269 for (auto [i, input, offset] :
270 llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271 SmallVector<OpFoldResult> srcSizes =
272 tensor::getMixedSizes(rewriter, loc, input);
273 srcOffsets[dim] = offset;
274
275 SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276 SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277 SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278
279 if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280 allEqual(srcStrides, dstStrides)) {
281 Value operand = concatOp.getOperand(i);
282 if (operand.getType() == extractOp.getResultType())
283 rewriter.replaceOp(extractOp, operand);
284 break;
285 }
286 }
287
288 return success();
289 }
290};
291
292/// Rewriting rule that fuses sparse_tensor.convert into producer.
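/// For illustration, a hypothetical sketch (#SV denotes some sparse encoding):
///
///   %e = tensor.empty() : tensor<10xf32>
///   %0 = linalg.generic ... outs(%e : tensor<10xf32>) -> tensor<10xf32>
///   %1 = sparse_tensor.convert %0 : tensor<10xf32> to tensor<10xf32, #SV>
///
/// becomes a generic op whose init is a sparse tensor.empty, so the result is
/// materialized directly in the destination format and the convert goes away.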
293struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294public:
295 using OpRewritePattern<ConvertOp>::OpRewritePattern;
296
297 LogicalResult matchAndRewrite(ConvertOp op,
298 PatternRewriter &rewriter) const override {
299 auto producer = op.getSource().getDefiningOp<GenericOp>();
300 if (!producer || producer.getDpsInits().size() != 1 ||
301 !isMaterializing(producer.getDpsInitOperand(0), false) ||
302 !producer.getResult(0).hasOneUse()) {
303 return failure();
304 }
305 // Clone the materialization operation, but update the result to sparse.
306 rewriter.setInsertionPoint(producer);
307 Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
308 Operation *cloned = rewriter.clone(*init);
309 cloned->getResult(0).setType(op.getResult().getType());
310
311 rewriter.modifyOpInPlace(producer, [&]() {
312 producer.getDpsInitsMutable().assign(cloned->getResults());
313 producer.getResult(0).setType(op.getResult().getType());
314 });
315
316 rewriter.replaceAllOpUsesWith(op, producer);
317 op->erase();
318
319 return success();
320 }
321};
322
323/// Rewriting rule that folds a direct yield of a zero value into the initial
/// allocation of the output tensor.
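/// For illustration, a hypothetical sketch (#SV denotes some sparse encoding):
///
///   %a = bufferization.alloc_tensor() : tensor<10x10xf64, #SV>
///   %0 = linalg.generic ... outs(%a : tensor<10x10xf64, #SV>) {
///     ^bb0(...): linalg.yield %zero : f64
///   } -> tensor<10x10xf64, #SV>
///
/// is replaced by %a itself (for a dense output with static shape, the op is
/// replaced by a constant zero tensor instead).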
324struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325public:
326 using OpRewritePattern<GenericOp>::OpRewritePattern;
327
328 LogicalResult matchAndRewrite(GenericOp op,
329 PatternRewriter &rewriter) const override {
330 if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331 !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
332 !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333 return failure();
334 auto outputType = getRankedTensorType(op.getResult(0));
335 // Yielding zero on a newly materialized sparse tensor can be
336 // optimized directly (regardless of dynamic or static size).
337 if (getSparseTensorEncoding(outputType)) {
338 rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339 return success();
340 }
341 // Use static zero value directly instead of materialization.
342 if (!outputType.hasStaticShape())
343 return failure();
344 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
345 rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
346 rewriter.eraseOp(def);
347 return success();
348 }
349};
350
351/// Rewriting rule that converts two kernels:
352///
353/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
354/// X(i,j) = S(i,j) * T(i,j)
355///
356/// into a single kernel, using distributive law:
357///
358/// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
359///
360/// This kind of fusion (merging two ops into one but using arithmetic
361/// equalities that may not hold for floating-point computations) would
362/// be undesirable in the dense case, since we distribute the multiplication
363/// into the reduction loop. However, for a sparse sampling tensor S, such
364/// a fusion may actually reduce the asymptotic complexity of the kernel,
365/// since intermediate results may be nullified.
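/// A well-known instance of this rewriting (given here only as an
/// illustration) is a sampled dense-dense matrix product, where a sparse
/// sampling matrix S gates a dense contraction:
///
///   X(i,j) = S(i,j) * SUM(k, A(i,k) * B(k,j))
///          = SUM(k, S(i,j) * A(i,k) * B(k,j))
///
/// so the multiplication with S(i,j) only needs to be evaluated for the
/// coordinates actually stored in S.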
366struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
367public:
368 using OpRewritePattern<GenericOp>::OpRewritePattern;
369
370 LogicalResult matchAndRewrite(GenericOp op,
371 PatternRewriter &rewriter) const override {
372 // Check consumer.
373 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374 op.getNumResults() != 1 ||
375 op.getNumParallelLoops() != op.getNumLoops() ||
376 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379 return failure();
380 // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
381 // operand can be sparse or dense, since the point of this rewriting rule
382 // is detecting a situation in which *more* sparsity is introduced into
383 // a computation, be it already sparse or still dense.
384 unsigned other = 0;
385 if (isSparseTensor(op.getDpsInputOperand(0)))
386 other = 1;
387 else if (!isSparseTensor(op.getDpsInputOperand(1)))
388 return failure();
389 // Check producer.
390 auto prod = dyn_cast_or_null<GenericOp>(
391 op.getDpsInputOperand(other)->get().getDefiningOp());
392 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393 !prod.getResult(0).hasOneUse())
394 return failure();
395 // Sampling consumer and sum of multiplication chain producer.
396 if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
397 !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398 !isSampling(op) || !isSumOfMul(prod))
399 return failure();
400 // Modify operand structure of producer and consumer.
401 Location loc = prod.getLoc();
402 SmallVector<Value> inputOps = prod.getInputs();
403 SmallVector<Value> outputOps = op.getOutputs();
404 SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
405 inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406 fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
407 // Fuse producer and consumer into a new generic op.
408 auto fusedOp = GenericOp::create(
409 rewriter, loc, op.getResult(0).getType(), inputOps, outputOps,
410 rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
411 /*doc=*/nullptr, /*library_call=*/nullptr);
412 Block &prodBlock = prod.getRegion().front();
413 Block &consBlock = op.getRegion().front();
414 IRMapping mapper;
415 Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
416 unsigned num = prodBlock.getNumArguments();
417 for (unsigned i = 0; i < num - 1; i++)
418 addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419 addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
420 addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421 // Clone bodies of the producer and consumer in new evaluation order.
422 auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
423 auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
424 Value last;
425 for (auto &op : prodBlock.without_terminator())
426 if (&op != acc) {
427 last = op.getResult(0);
428 rewriter.clone(op, mapper);
429 }
430 mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
431 mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
432 last = rewriter.clone(*acc, mapper)->getResult(0);
433 linalg::YieldOp::create(rewriter, loc, last);
434 // Force initial value on merged allocation for dense outputs.
435 // TODO: deal with non alloc tensor here one day
436 if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437 Value init = prod.getDpsInitOperand(0)
438 ->get()
439 .getDefiningOp<AllocTensorOp>()
440 .getCopy();
441 AllocTensorOp a =
442 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443 rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444 }
445 // Replace consumer with fused operation. Old producer
446 // and consumer ops will be removed by DCE.
447 rewriter.replaceOp(op, fusedOp->getResults());
448 return success();
449 }
450
451private:
452 // Helper to add argument and record the mapping.
453 static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
454 mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
455 }
456};
457
458// Fuse a tensor cast into its producing operation. Note that a tensor.cast
459// should really not be used to convert between sparse encodings. Since
460// the pattern currently appears as a result of some prior rewriting,
461// we make an attempt to repair very obvious cases.
462// TODO: audit the pure tensor dialect rewriting rules
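// For illustration, a hypothetical repair (#CSR denotes some sparse encoding):
//
//   %1 = tensor.cast %0 : tensor<?x?xf64> to tensor<?x?xf64, #CSR>
//
// is rewritten into the properly supported
//
//   %1 = sparse_tensor.convert %0 : tensor<?x?xf64> to tensor<?x?xf64, #CSR>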
463struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
464public:
465 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
466
467 LogicalResult matchAndRewrite(tensor::CastOp op,
468 PatternRewriter &rewriter) const override {
469 Type srcType = op.getSource().getType();
470 Type dstType = op.getDest().getType();
471 // A nop cast simply folds away.
472 if (srcType == dstType) {
473 rewriter.replaceOp(op, op->getResults());
474 return success();
475 }
476 // See if a sparsity changing cast can be fused into producer.
477 if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
478 if (Operation *def = op.getSource().getDefiningOp()) {
479 if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
480 rewriter.modifyOpInPlace(def, [&]() {
481 def->getResult(0).setType(op->getResultTypes()[0]);
482 });
483 rewriter.replaceOp(op, def->getResult(0));
484 return success();
485 }
486 }
487 }
488 // Repair tensor casts with at least one sparse operand into the
489 // properly supported sparse_tensor.convert.
490 if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
491 rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
492 return success();
493 }
494 // Fail otherwise.
495 return failure();
496 }
497};
498
499/// Rewrites a sequence of operations for sparse tensor selections into
500/// semi-ring operations such that they can be compiled correctly by the
501/// sparsifier. E.g., transforming the following sequence
502///
503/// %sel = arith.select %cond, %sp1, %sp2
504///
505/// to
506///
507/// %sel = binary %sp1, %sp2:
508/// both (%l, %r) {yield select %cond, %l, %r}
509/// left (%l) {yield select %cond, %l, 0}
510/// right (%r) {yield select %cond, 0, %r}
511///
512/// TODO: We require the tensor used for extracting conditions to be dense
513/// in order to sparsify the code. To support a sparse condition tensor, we
514/// need a ternary operation.
515struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516public:
517 using OpRewritePattern<GenericOp>::OpRewritePattern;
518 LogicalResult matchAndRewrite(GenericOp op,
519 PatternRewriter &rewriter) const override {
520 // Rejects non sparse kernels.
521 if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522 return failure();
523
524 Location loc = op.getLoc();
525 SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
526 for (Operation &inst : *op.getBody()) {
527 // Matches pattern.
528 auto matched = isRewritablePattern(op, &inst);
529 if (!matched.has_value())
530 continue;
531
532 rewriter.setInsertionPoint(&inst);
533 auto [c, t, f] = matched.value();
534 assert(t.getType() == f.getType());
535 auto selTp = t.getType();
536 auto c0 = constantZero(rewriter, loc, selTp);
537 auto binOp = sparse_tensor::BinaryOp::create(rewriter, loc, selTp, t, f);
538 // Initializes all the blocks.
539 rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540 {t.getLoc(), f.getLoc()});
541 rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542 rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543
544 for (auto *r : binOp.getRegions()) {
545 Block *b = &r->front();
546 rewriter.setInsertionPointToStart(b);
547
548 IRMapping irMap;
549 // Clones the cmp operations into the region to make the binary op
550 // admissible.
551 Value newC = c;
552 if (auto *def = c.getDefiningOp())
553 newC = rewriter.clone(*def, irMap)->getResult(0);
554
555 irMap.map(c, newC);
556 if (r == &binOp.getLeftRegion()) {
557 irMap.map(t, b->getArgument(0));
558 irMap.map(f, c0);
559 } else if (r == &binOp.getRightRegion()) {
560 irMap.map(t, c0);
561 irMap.map(f, b->getArgument(0));
562 } else {
563 irMap.map(t, b->getArgument(0));
564 irMap.map(f, b->getArgument(1));
565 }
566 auto y = rewriter.clone(inst, irMap)->getResult(0);
567 sparse_tensor::YieldOp::create(rewriter, loc, y);
568 }
569
570 // We successfully rewrote an operation. We cannot do the replacement
571 // here because it would invalidate the iterator used by the current
572 // loop to traverse the instructions.
573 semiRings.emplace_back(&inst, binOp);
574 }
575
576 // Finalizes the replacement.
577 for (auto [sel, semi] : semiRings)
578 rewriter.replaceOp(sel, semi->getResults());
579
580 return success(!semiRings.empty());
581 }
582
583private:
584 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585 isRewritablePattern(GenericOp op, Operation *v) {
586 auto sel = dyn_cast<arith::SelectOp>(v);
587 if (!sel)
588 return std::nullopt;
589
590 auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591 auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
592 // TODO: For simplicity, we only handle cases where both the true/false
593 // values are loaded directly from the input tensor. We can probably admit
594 // more cases in theory.
595 if (!tVal || !fVal)
596 return std::nullopt;
597
598 // Helper lambda to determine whether the value is loaded from a dense input
599 // or is a loop invariant.
600 auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601 if (auto bArg = dyn_cast<BlockArgument>(v);
602 bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603 return true;
604 // If the value is defined outside the loop, it is a loop invariant.
605 return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606 };
607
608 // If the condition value is loaded directly from a dense tensor or is a
609 // loop invariant, we can sparsify the kernel.
610 auto cond = sel.getCondition();
611 if (isValFromDenseInputOrInvariant(cond))
612 return std::make_tuple(cond, tVal, fVal);
613
614 Value cmpL, cmpR;
615 if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616 matchers::m_Any(&cmpR))) ||
617 matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618 matchers::m_Any(&cmpR)))) {
619 // TODO: we can do it recursively to check whether all the leaf values are
620 // loaded from dense tensors or are loop invariants.
621 if (isValFromDenseInputOrInvariant(cmpL) ||
622 isValFromDenseInputOrInvariant(cmpR))
623 return std::make_tuple(cond, tVal, fVal);
624 }
625
626 return std::nullopt;
627 };
628};
629
630/// Rewrites a sparse reduction that would not sparsify directly (since
631/// doing so would only iterate over the stored elements, ignoring the
632/// implicit zeros) into a semi-ring operation. Applies to all prod/and/min/max
633/// (note that reductions like add/sub/or/xor can directly be sparsified
634/// since the implicit zeros do not contribute to the final result).
635/// Note that prod/and are still included since, even though they often
636/// are nullified in sparse data, they may still occur for special
637/// situations in which e.g. some rows in a sparse matrix are fully
638/// dense. For min/max, including the implicit zeros is a much more
639/// common situation.
640///
641/// TODO: this essentially "densifies" the operation; we want to implement
642/// this much more efficiently by performing the reduction over the
643/// stored values, and feed in the zero once if there were *any*
644/// implicit zeros as well; but for now, at least we provide
645/// the functionality
646///
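/// For illustration, a min-reduction body x = min(x, a) is rewritten
/// (schematically, with made-up value names) into
///
///   %u = sparse_tensor.unary %a {
///          present = { yield %a }
///          absent  = { yield %zero } }
///   %x = sparse_tensor.reduce %u, %x, %identity { min }
///
/// so that the implicit zeros participate in the reduction through the
/// absent branch, while %identity is read from the initial output value.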
647struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
648public:
649 using OpRewritePattern<GenericOp>::OpRewritePattern;
650
651 LogicalResult matchAndRewrite(GenericOp op,
652 PatternRewriter &rewriter) const override {
653 // Reject non-reductions.
654 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655 op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656 return failure();
657 auto *inp = op.getDpsInputOperand(0);
658 auto *init = op.getDpsInitOperand(0);
659 if (!isSparseTensor(inp))
660 return failure();
661 // Look for direct x = x OP y for semi-ring ready reductions.
662 auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
663 .getOperand(0)
664 .getDefiningOp();
665 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667 arith::MaxUIOp>(red))
668 return failure();
669 Value s0 = op.getBlock()->getArgument(0);
670 Value s1 = op.getBlock()->getArgument(1);
671 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673 return failure();
674 // Identity.
675 Location loc = op.getLoc();
676 Value identity =
677 tensor::ExtractOp::create(rewriter, loc, init->get(), ValueRange());
678 // Unary {
679 // present -> value
680 // absent -> zero.
681 // }
682 Type rtp = s0.getType();
683 rewriter.setInsertionPointToStart(&op.getRegion().front());
684 auto semiring = sparse_tensor::UnaryOp::create(rewriter, loc, rtp, s0);
685 Block *present =
686 rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
687 rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
688 sparse_tensor::YieldOp::create(rewriter, loc, present->getArgument(0));
689 rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
690 rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
691 auto zero =
692 arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(rtp));
693 sparse_tensor::YieldOp::create(rewriter, loc, zero);
694 rewriter.setInsertionPointAfter(semiring);
695 // CustomReduce {
696 // x = x REDUC y, identity
697 // }
698 auto custom = sparse_tensor::ReduceOp::create(
699 rewriter, loc, rtp, semiring.getResult(), s1, identity);
700 Block *region =
701 rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
702 rewriter.setInsertionPointToStart(&custom.getRegion().front());
703 IRMapping irMap;
704 irMap.map(red->getOperand(0), region->getArgument(0));
705 irMap.map(red->getOperand(1), region->getArgument(1));
706 auto *cloned = rewriter.clone(*red, irMap);
707 sparse_tensor::YieldOp::create(rewriter, loc, cloned->getResult(0));
708 rewriter.setInsertionPointAfter(custom);
709 rewriter.replaceOp(red, custom.getResult());
710 return success();
711 }
712};
713
714/// Sparse rewriting rule for the print operator. This operation is mainly used
715/// for debugging and testing. As such, it lowers to the vector.print operation,
716/// which only requires very lightweight runtime support.
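/// For example, a small CSR matrix might print roughly as follows
/// (illustrative values only):
///
///   ---- Sparse Tensor ----
///   nse = 3
///   dim = ( 4, 8 )
///   lvl = ( 4, 8 )
///   pos[1] : ( 0, 1, 1, 2, 3 )
///   crd[1] : ( 2, 0, 7 )
///   values : ( 1, 2, 3 )
///   ----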
717struct PrintRewriter : public OpRewritePattern<PrintOp> {
718public:
719 using OpRewritePattern<PrintOp>::OpRewritePattern;
720 LogicalResult matchAndRewrite(PrintOp op,
721 PatternRewriter &rewriter) const override {
722 Location loc = op.getLoc();
723 auto tensor = op.getTensor();
724 auto stt = getSparseTensorType(tensor);
725 // Header with NSE.
726 auto nse = NumberOfEntriesOp::create(rewriter, loc, tensor);
727 vector::PrintOp::create(
728 rewriter, loc,
729 rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
730 vector::PrintOp::create(rewriter, loc, nse);
731 // Print run-time contents for dim/lvl sizes.
732 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("dim = "));
733 printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
734 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("lvl = "));
735 printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
736 // Use the "codegen" foreach loop construct to iterate over
737 // all typical sparse tensor components for printing.
738 foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
739 &stt](Type, FieldIndex,
740 SparseTensorFieldKind kind,
741 Level l, LevelType) {
742 switch (kind) {
743 case SparseTensorFieldKind::StorageSpec: {
744 break;
745 }
746 case SparseTensorFieldKind::PosMemRef: {
747 auto lvl = constantIndex(rewriter, loc, l);
748 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("pos["));
749 vector::PrintOp::create(rewriter, loc, lvl,
750 vector::PrintPunctuation::NoPunctuation);
751 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
752 auto pos = ToPositionsOp::create(rewriter, loc, tensor, l);
753 printContents(rewriter, loc, pos);
754 break;
755 }
756 case SparseTensorFieldKind::CrdMemRef: {
757 auto lvl = constantIndex(rewriter, loc, l);
758 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("crd["));
759 vector::PrintOp::create(rewriter, loc, lvl,
760 vector::PrintPunctuation::NoPunctuation);
761 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
762 Value crd = nullptr;
763 // For COO AoS storage, we want to print a single, linear view of
764 // the full coordinate storage at this level. For any other storage,
765 // we show the coordinate storage for every individual level.
766 if (stt.getAoSCOOStart() == l)
767 crd = ToCoordinatesBufferOp::create(rewriter, loc, tensor);
768 else
769 crd = ToCoordinatesOp::create(rewriter, loc, tensor, l);
770 printContents(rewriter, loc, crd);
771 break;
772 }
773 case SparseTensorFieldKind::ValMemRef: {
774 vector::PrintOp::create(rewriter, loc,
775 rewriter.getStringAttr("values : "));
776 auto val = ToValuesOp::create(rewriter, loc, tensor);
777 printContents(rewriter, loc, val);
778 break;
779 }
780 }
781 return true;
782 });
783 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("----\n"));
784 rewriter.eraseOp(op);
785 return success();
786 }
787
788private:
789 // Helper to print contents of a single memref. For "push_back" vectors,
790 // we assume that the previous getters for pos/crd/val have added a
791 // slice-to-size view to make sure we just print the size and not the
792 // full capacity.
793 //
794 // Generates code to print (1-dim or higher):
795 // ( a0, a1, ... )
796 static void printContents(PatternRewriter &rewriter, Location loc,
797 Value vec) {
798 auto shape = cast<ShapedType>(vec.getType()).getShape();
799 SmallVector<Value> idxs;
800 printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
801 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
802 }
803
804 // Helper to the helper.
805 static void printContentsLevel(PatternRewriter &rewriter, Location loc,
806 Value vec, unsigned i, ArrayRef<int64_t> shape,
807 SmallVectorImpl<Value> &idxs) {
808 // Open bracket.
809 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
810 // Generate for loop.
811 auto zero = constantIndex(rewriter, loc, 0);
812 auto index = constantIndex(rewriter, loc, i);
813 auto size = memref::DimOp::create(rewriter, loc, vec, index);
814 auto step = constantIndex(rewriter, loc, 1);
815 auto forOp = scf::ForOp::create(rewriter, loc, zero, size, step);
816 idxs.push_back(forOp.getInductionVar());
817 rewriter.setInsertionPointToStart(forOp.getBody());
818 if (i < shape.size() - 1) {
819 // Enter deeper loop nest.
820 printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
821 } else {
822 // Actual contents printing.
823 auto val = memref::LoadOp::create(rewriter, loc, vec, idxs);
824 if (llvm::isa<ComplexType>(val.getType())) {
825 // Since the vector dialect does not support complex types in any op,
826 // we split those into (real, imag) pairs here.
827 Value real = complex::ReOp::create(rewriter, loc, val);
828 Value imag = complex::ImOp::create(rewriter, loc, val);
829 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
830 vector::PrintOp::create(rewriter, loc, real,
831 vector::PrintPunctuation::Comma);
832 vector::PrintOp::create(rewriter, loc, imag,
833 vector::PrintPunctuation::Close);
834 } else {
835 vector::PrintOp::create(rewriter, loc, val,
836 vector::PrintPunctuation::NoPunctuation);
837 }
838 // Terminating comma (except at end).
839 auto bound = arith::AddIOp::create(rewriter, loc, idxs.back(), step);
840 Value cond = arith::CmpIOp::create(rewriter, loc,
841 arith::CmpIPredicate::ne, bound, size);
842 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, cond, /*else*/ false);
843 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
844 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Comma);
845 }
846 idxs.pop_back();
847 rewriter.setInsertionPointAfter(forOp);
848 // Close bracket.
849 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
850 }
851
852 // Helper method to print run-time lvl/dim sizes.
853 static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
854 unsigned size, bool isDim) {
855 // Open bracket.
856 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
857 // Print unrolled contents (dimop requires constant value).
858 for (unsigned i = 0; i < size; i++) {
859 auto idx = constantIndex(rewriter, loc, i);
860 Value val;
861 if (isDim)
862 val = tensor::DimOp::create(rewriter, loc, tensor, idx);
863 else
864 val = LvlOp::create(rewriter, loc, tensor, idx);
865 vector::PrintOp::create(rewriter, loc, val,
866 i != size - 1
867 ? vector::PrintPunctuation::Comma
868 : vector::PrintPunctuation::NoPunctuation);
869 }
870 // Close bracket and end of line.
871 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
872 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
873 }
874};
875
876/// Sparse rewriting rule for the sparse-to-sparse reshape (tensor.reshape)
/// operator.
877struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
878public:
879 using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
880
881 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
882 PatternRewriter &rewriter) const override {
883 Location loc = op.getLoc();
884 Value srcTensor = op.getSource();
885 const auto srcTp = tryGetSparseTensorType(srcTensor);
886 const auto dstTp = tryGetSparseTensorType(op.getResult());
887 if (!srcTp || !dstTp)
888 return failure();
889
890 if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
891 !dstTp->hasStaticDimShape())
892 return failure();
893
894 SmallVector<Value> srcSizes;
895 sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
896 SmallVector<Value> dstSizes;
897 for (Dimension d : dstTp->getDimShape())
898 dstSizes.push_back(constantIndex(rewriter, loc, d));
899
900 Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
901 // Only need an unordered COO buffer if input and output are not sorted
902 // in the same way.
903 Type bufferTp = getBufferType(
904 dstTp->withoutDimToLvl(),
905 !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
906 SmallVector<Value> dynSizes;
907 Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes,
908 Value(), nnz, Attribute())
909 .getResult();
910
911 // Convert src coordinates to dst coordinates by first collapsing them to
912 // 1D and then expanding them to match the rank of the destination tensor.
913 // Implemented as follows:
914 // foreach srcCoords %srcTensor
915 // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
916 // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
917 // insert expandedCoords, %buffer
918 //
919 // followed by an optional
920 // %t = sparse_tensor.cast %tmp
921 // depending on whether the input/output are sorted in the same way.
922 const auto encSrc = srcTp->getEncoding();
923 ForeachOp foreachOp = ForeachOp::create(
924 rewriter, loc, srcTensor, buffer,
925 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
926 ValueRange reduc) {
927 const Dimension srcRank = srcTp->getDimRank();
928 SmallVector<Value> srcDcvs;
929 srcDcvs.reserve(srcRank);
930 for (Dimension d = 0; d < srcRank; d++) {
931 Level lvl = toLvl(encSrc, d);
932 srcDcvs.push_back(srcLcvs[lvl]);
933 }
934
935 Value collapseSize = constantIndex(builder, loc, 1);
936 for (Dimension d = 0; d < srcRank; d++)
937 collapseSize =
938 arith::MulIOp::create(builder, loc, collapseSize, srcSizes[d]);
939 SmallVector<Value, 1> collapsedSizes = {collapseSize};
940
941 ReassociationIndices collapseIdx;
942 for (Dimension i = 0; i < srcRank; i++)
943 collapseIdx.push_back(i);
944 SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
945 SmallVector<Value, 1> collapsedDcvs;
946 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947 collapsedSizes, collapsedDcvs);
948
949 ReassociationIndices expandIdx;
950 for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951 expandIdx.push_back(i);
952 SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
953 SmallVector<Value> dstDcvs;
954 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955 dstSizes, dstDcvs);
956
957 auto t =
958 tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
959 sparse_tensor::YieldOp::create(builder, loc, t);
960 });
961
962 Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true);
963 if (bufferTp != *dstTp) {
964 auto dstRTT = dstTp->getRankedTensorType();
965 Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
966 DeallocTensorOp::create(rewriter, loc, t);
967 t = converted;
968 }
969 rewriter.replaceOp(op, t);
970 return success();
971 }
972};
973
974/// Sparse rewriting rule for the sparse-to-sparse collapse/expand shape
/// operators.
975template <typename ReshapeOp>
976struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
977public:
978 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
979
980 LogicalResult matchAndRewrite(ReshapeOp op,
981 PatternRewriter &rewriter) const override {
982 Location loc = op.getLoc();
983 Value srcTensor = op.getSrc();
984 const auto srcTp = getSparseTensorType(srcTensor);
985 const auto dstTp = getSparseTensorType(op.getResult());
986 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987 return failure();
988
989 // Generate code to represent the static dimension constants or compute
990 // the dynamic dimension values.
991 SmallVector<Value> srcSizes;
992 sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
993 SmallVector<Value> dstSizes;
994 SmallVector<Value> dstDynSizes;
995 if (dstTp.hasStaticDimShape()) {
996 for (Dimension d : dstTp.getDimShape())
997 dstSizes.push_back(constantIndex(rewriter, loc, d));
998 } else {
999 ArrayRef<Size> dstShape = dstTp.getDimShape();
1000 genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
1001 op.getReassociationIndices());
1002 for (auto [idx, shape] : llvm::enumerate(dstShape)) {
1003 if (shape == ShapedType::kDynamic)
1004 dstDynSizes.push_back(dstSizes[idx]);
1005 }
1006 }
1007 Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
1008 // Only need an unordered COO buffer if input and output are not sorted
1009 // in the same way.
1010 Type bufferTp = getBufferType(
1011 dstTp.withoutDimToLvl(),
1012 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1013
1014 Value buffer =
1015 AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(),
1016 /*sizeHint=*/nnz, Attribute())
1017 .getResult();
1018
1019 // Implement the sparse2sparse reshape as follows:
1020 // foreach srcCoords %srcTensor
1021 // insert reshapeCvs(srcCoords), %buffer
1022 //
1023 // followed by an optional
1024 // %t = sparse_tensor.cast %tmp
1025 // depending on whether the input/output are sorted in the same way.
1026 const auto encSrc = srcTp.getEncoding();
1027 ForeachOp foreachOp = ForeachOp::create(
1028 rewriter, loc, srcTensor, buffer,
1029 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1030 ValueRange reduc) {
1031 const Dimension dimRank = srcTp.getDimRank();
1032 SmallVector<Value> srcDcvs;
1033 srcDcvs.reserve(dimRank);
1034 for (Dimension d = 0; d < dimRank; d++) {
1035 Level lvl = toLvl(encSrc, d);
1036 srcDcvs.push_back(srcLcvs[lvl]);
1037 }
1038 SmallVector<Value> dstDcvs;
1039 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1040 srcDcvs, dstSizes, dstDcvs);
1041 auto t =
1042 tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
1043 sparse_tensor::YieldOp::create(builder, loc, t);
1044 });
1045
1046 Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true);
1047 if (bufferTp != dstTp) {
1048 auto dstRTT = dstTp.getRankedTensorType();
1049 Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
1050 DeallocTensorOp::create(rewriter, loc, t);
1051 t = converted;
1052 }
1053 rewriter.replaceOp(op, t);
1054 return success();
1055 }
1056};
1057
1058/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1059/// operator.
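/// For illustration, a hypothetical dense-to-sparse expansion (#SV2D denotes
/// some 2-D sparse encoding; the syntax below is schematic):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///          : tensor<100xf64> into tensor<10x10xf64, #SV2D>
///
/// is rewritten into a dense expansion followed by a conversion:
///
///   %e = tensor.expand_shape %0 [[0, 1]]
///          : tensor<100xf64> into tensor<10x10xf64>
///   %1 = sparse_tensor.convert %e
///          : tensor<10x10xf64> to tensor<10x10xf64, #SV2D>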
1060template <typename ReshapeOp>
1061struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1062public:
1063 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1064
1065 LogicalResult matchAndRewrite(ReshapeOp op,
1066 PatternRewriter &rewriter) const override {
1067 Location loc = op->getLoc();
1068 auto encDst = getSparseTensorEncoding(op.getResult().getType());
1069 auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1070 // Since a pure dense expansion is very cheap (change of view), for
1071 // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1072 // conversion from the reshape operation itself.
1073 // All other cases are handled elsewhere.
1074 if (encDst && encSrc) {
1075 return failure();
1076 }
1077 if (encSrc) {
1078 auto rtp = getRankedTensorType(op.getSrc());
1079 auto denseTp =
1080 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1081 auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc());
1082 rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1083 return success();
1084 }
1085 if (encDst) {
1086 auto rtp = getRankedTensorType(op.getResult());
1087 auto denseTp =
1088 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1089 ReshapeOp reshape;
1090 if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1091 reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
1092 op.getReassociation(), op.getOutputShape(),
1093 op.getStaticOutputShape());
1094 } else {
1095 reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
1096 op.getReassociation());
1097 }
1098 Value convert = ConvertOp::create(rewriter, loc, rtp, reshape);
1099 rewriter.replaceOp(op, convert);
1100 return success();
1101 }
1102 return failure();
1103 }
1104};
1105
1106// A trivial wrapper to help generate different operations for dense/sparse
1107// tensors.
1108struct TensorLike {
1109 TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1110 ValueRange sizes) {
1111 SmallVector<Value> dynSzs;
1112 getDynamicSizes(rtt, sizes, dynSzs);
1113
1114 val = AllocTensorOp::create(builder, loc, rtt, dynSzs);
1115 if (!isSparse()) {
1116 Value c0 = constantZero(builder, loc, rtt.getElementType());
1117 val = linalg::FillOp::create(builder, loc, c0, val).getResult(0);
1118 }
1119 }
1120
1121 void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1122 val = tensor::InsertOp::create(builder, loc, v, val, crds);
1123 }
1124
1125 Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1126 if (isSparse())
1127 return LoadOp::create(builder, loc, val, true);
1128 return val;
1129 }
1130
1131 bool isSparse() const {
1132 return getSparseTensorEncoding(val.getType()) != nullptr;
1133 }
1134
1135 Value val;
1136};
1137
1138struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1139 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1140 LogicalResult matchAndRewrite(tensor::DimOp op,
1141 PatternRewriter &rewriter) const override {
1142 std::optional<int64_t> dim = op.getConstantIndex();
1143 auto stt = tryGetSparseTensorType(op.getSource());
1144 if (!dim || !stt || !stt->hasEncoding())
1145 return failure();
1146
1147 if (stt->isPermutation()) {
1148 rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1149 toLvl(stt->getEncoding(), *dim));
1150 return success();
1151 }
1152
1153 // Non-permutation dim2lvl/lvl2dim maps.
1154 // Compute as follows:
1155 // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1156 // Note that this is not the most efficient way (but a more general one) to
1157 // do the lvl to dim translation; e.g., for BSR, the dimension size can be
1158 // computed simply as lvl_size * block_size.
1159 Location loc = op.getLoc();
1160 SmallVector<Value> maxLvlCrds;
1161 for (Level l = 0; l < stt->getLvlRank(); l++) {
1162 Value lvlSz = LvlOp::create(rewriter, loc, op.getSource(), l);
1163 Value maxLvlCrd = arith::SubIOp::create(
1164 rewriter, loc, lvlSz,
1165 constantOne(rewriter, loc, rewriter.getIndexType()));
1166 maxLvlCrds.push_back(maxLvlCrd);
1167 }
1168
1169 AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170 Value maxDimCrd = affine::AffineApplyOp::create(
1171 rewriter, op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172 maxLvlCrds);
1173
1174 Value dimSz = arith::AddIOp::create(
1175 rewriter, loc, maxDimCrd,
1176 constantOne(rewriter, loc, rewriter.getIndexType()));
1177 rewriter.replaceOp(op, dimSz);
1178 return success();
1179 }
1180};
1181
1182struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1183 using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
1184 LogicalResult matchAndRewrite(ConcatenateOp op,
1185 PatternRewriter &rewriter) const override {
1186 if (op.needsExtraSort())
1187 op.emitError("ConcatenateOp not staged");
1188
1189 const Location loc = op.getLoc();
1190 const auto dstTp = getSparseTensorType(op);
1191 const Dimension conDim = op.getDimension();
1192 SmallVector<Value> sizes;
1193 concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1194
1195 // %t = concatenate %s1, %s2, %s3 {dim = 1}
1196 // ==>
1197 // if (isSparseDst)
1198 // if (allDense)
1199 // %tmp = bufferization.alloc_tensor dstTp
1200 // else
1201 // %tmp = bufferization.alloc_tensor : unordered COO
1202 // else
1203 // %tmp = memref.alloc : dense tensor
1204 // foreach in %s1 : insert d0, d1, %tmp
1205 // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1206 // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1207
1208 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1209 Value offset = constantIndex(rewriter, loc, 0);
1210 Value iterArg = dstBuf.val;
1211
1212 ForeachOp foreachOp;
1213 for (Value input : op.getInputs()) {
1214 // Builds a for op for each input tensor to append new values into the
1215 // output tensor.
1216 foreachOp = ForeachOp::create(
1217 rewriter, loc, input, iterArg,
1218 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1219 ValueRange reduc) {
1220 SmallVector<Value> offDimCrd(dcvs);
1221 offDimCrd[conDim] =
1222 arith::AddIOp::create(builder, loc, offDimCrd[conDim], offset);
1223
1224 // Enters foreach, updates the SSA chain.
1225 dstBuf.val = reduc.front();
1226 if (!dstTp.isAllDense()) {
1227 Value cond = genIsNonzero(builder, loc, v);
1228 auto ifOp =
1229 scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
1230 /*else*/ true);
1231 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1232 scf::YieldOp::create(builder, loc, dstBuf.val);
1233
1234 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1235 dstBuf.insert(builder, loc, v, offDimCrd);
1236 scf::YieldOp::create(builder, loc, dstBuf.val);
1237
1238 // Exits the ifOp, update the sparse tensor SSA value.
1239 builder.setInsertionPointAfter(ifOp);
1240 dstBuf.val = ifOp.getResult(0);
1241 } else {
1242 dstBuf.insert(builder, loc, v, offDimCrd);
1243 }
1244 sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
1245 });
1246 // Accumulates the offset. Note that only static-shaped inputs are allowed
1247 // by the concatenate op verifier, which saves us from computing the offset
1248 // dynamically.
1249 const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1250 assert(ShapedType::isStatic(sz));
1251 offset = arith::AddIOp::create(rewriter, loc, offset,
1252 constantIndex(rewriter, loc, sz));
1253 iterArg = foreachOp.getResult(0);
1254 dstBuf.val = iterArg;
1255 }
1256
1257 dstBuf.val = iterArg;
1258 Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1259 rewriter.replaceOp(op, ret);
1260 return success();
1261 }
1262};
1263
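/// Sparse rewriting rule for the (already staged) convert operator. The
/// conversion is implemented directly with a foreach loop that inserts every
/// element (or every nonzero element, when the source is dense) into a
/// freshly allocated destination, roughly:
///
///   foreach srcCoords, v in %src
///     insert v at srcCoords into %tmp
///   %dst = sparse_tensor.load %tmp hasInserts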
1264struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1265 using OpRewritePattern<ConvertOp>::OpRewritePattern;
1266 LogicalResult matchAndRewrite(ConvertOp op,
1267 PatternRewriter &rewriter) const override {
1268 if (op.needsExtraSort())
1269 return op.emitError("ConvertOp not staged.");
1270
1271 // TODO: Maybe we want a different operation for this too.
1272 auto encDst = getSparseTensorEncoding(op.getType());
1273 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1274 if (encDst && encSrc && !encSrc.isSlice() &&
1275 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1276 // Trivial tensor conversion and simple element type conversion are handled
1277 // in codegen.
1278 return failure();
1279 }
1280
1281 Location loc = op.getLoc();
1282 Value src = op.getSource();
1283
1284 SparseTensorType srcStt = getSparseTensorType(op.getSource());
1285 SparseTensorType dstStt = getSparseTensorType(op.getDest());
1286
1287 bool fromSparseConst = false;
1288 if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1289 if (isa<SparseElementsAttr>(constOp.getValue()))
1290 fromSparseConst = true;
1291
1292 const AffineMapAttr foreachOrder =
1293 (!dstStt.isIdentity() && fromSparseConst)
1294 ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1295 : nullptr;
1296
1297 bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1298
1299 SmallVector<Value> sizes;
1300 sizesFromSrc(rewriter, sizes, loc, src);
1301 ValueRange vs;
1302 TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1303
1304 auto foreachOp = ForeachOp::create(
1305 rewriter, loc, src, dstBuf.val, foreachOrder,
1306 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1307 ValueRange reduc) {
1308 // Enters the loop, update the SSA value for insertion chain.
1309 dstBuf.val = reduc.front();
1310 if (!skipZeroCheck) {
1311 Value cond = genIsNonzero(builder, loc, v);
1312 auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
1313 /*else*/ true);
1314 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1315 scf::YieldOp::create(builder, loc, dstBuf.val);
1316
1317 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1318 dstBuf.insert(builder, loc, v, dcvs);
1319 scf::YieldOp::create(builder, loc, dstBuf.val);
1320
1321 // Exits the ifOp, update the sparse tensor SSA value.
1322 builder.setInsertionPointAfter(ifOp);
1323 dstBuf.val = ifOp.getResult(0);
1324 } else {
1325 dstBuf.insert(builder, loc, v, dcvs);
1326 }
1327 sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
1328 });
1329
1330 rewriter.setInsertionPointAfter(foreachOp);
1331
1332 // Exits the for loop, links the SSA chain.
1333 dstBuf.val = foreachOp.getResult(0);
1334
1335 Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1336 rewriter.replaceOp(op, ret);
1337 return success();
1338 }
1339};
1340
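/// Sparse rewriting rule for the crd_translate operator: expands the
/// dim-to-lvl (or lvl-to-dim) coordinate translation into one affine.apply
/// per result of the corresponding map. For example (schematically), a 2x2
/// block encoding may translate dimension coordinates (d0, d1) into level
/// coordinates (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2).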
1341struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1342 using OpRewritePattern<CrdTranslateOp>::OpRewritePattern;
1343 LogicalResult matchAndRewrite(CrdTranslateOp op,
1344 PatternRewriter &rewriter) const override {
1345 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1346 ? op.getEncoder().getDimToLvl()
1347 : op.getEncoder().getLvlToDim();
1348
1349 SmallVector<Value> outCrds;
1350 for (AffineExpr result : map.getResults()) {
1351 // TODO: we should probably expand the affine map to IR using our own
1352 // rules, since affine.apply assumes signed values, while the coordinates
1353 // we provide must always be signless.
1354 Value trans = affine::AffineApplyOp::create(
1355 rewriter, op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1356 op.getInCrds());
1357 outCrds.push_back(trans);
1358 }
1359 rewriter.replaceOp(op, outCrds);
1360 return success();
1361 }
1362};
1363
1364/// Sparse rewriting rule for the foreach operator.
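/// The operator is lowered with the loop emitter. For a 1-D compressed
/// vector, the generated structure is roughly (an illustrative sketch):
///
///   %pos = sparse_tensor.positions %t ...
///   %crd = sparse_tensor.coordinates %t ...
///   %val = sparse_tensor.values %t ...
///   scf.for %i = %lo to %hi step %c1 iter_args(...) {
///     <inlined foreach body on coordinate %crd[%i] and value %val[%i]>
///   }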
1365struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1366public:
1367 using OpRewritePattern<ForeachOp>::OpRewritePattern;
1368
1369 LogicalResult matchAndRewrite(ForeachOp op,
1370 PatternRewriter &rewriter) const override {
1371
1372 auto loc = op.getLoc();
1373 Value input = op.getTensor();
1374 SmallVector<Value> reduc = op.getInitArgs();
1375 const auto stt = getSparseTensorType(input);
1376 const Level lvlRank = stt.getLvlRank();
1377
1378 // Special case: foreach over a sparse constant uses its own rewriting
1379 // rule.
1380 if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1381 if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1382 return genForeachOnSparseConstant(op, rewriter, attr);
1383 }
1384 }
1385
1386 // Otherwise, use loop emitter to generate loops.
1387 const auto enc = stt.getEncoding();
1388
1389 // 1. Generates loop for the sparse input.
1390 LoopEmitter loopEmitter(
1391 ValueRange{input},
1392 StringAttr::get(getContext(), ForeachOp::getOperationName()));
1393 loopEmitter.initializeLoopEmit(rewriter, loc);
1394 for (Level l = 0; l < lvlRank; l++) {
1395 // TODO: provide utility function for loop sequences that only contains
1396 // one for loop?
1397 const SmallVector<TensorLevel, 1> tidLvls{
1398 loopEmitter.makeTensorLevel(0, l)};
1399 loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1400 // Note that reduc will be taken care of by loop emitter and get updated
1401 // in place.
1402 loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1403 reduc);
1404 }
1405
1406 SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1407 if (op.getOrder()) {
1408 // TODO: Support it so that we can do direct conversion from CSR->BSR.
1409 llvm_unreachable(
1410 "Level order not yet implemented on non-constant input tensors.");
1411 }
1412
1413 Value vals = loopEmitter.getValBuffer()[0];
1414 SmallVector<Value> pos = loopEmitter.getValPosits(0);
1415 // Loads the value from the sparse tensor using the position index;
1416 // loads the value from the dense tensor using the coordinates.
1417 Value val = enc ? memref::LoadOp::create(rewriter, loc, vals, pos)
1418 : memref::LoadOp::create(rewriter, loc, vals, lcvs);
1419
1420 // 2. Inline the block in the foreach operator.
1421 Block *srcBlock = op.getBody();
1422
1423 // Remap coordinates.
1424 SmallVector<Value> args =
1425 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1426
1427 // Remap value.
1428 args.push_back(val);
1429 // Remap reduction variables.
1430 args.append(reduc);
1431
1432 // Remove sparse_tensor.yield.
1433 SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1434 rewriter.eraseOp(srcBlock->getTerminator());
1435
1436 Operation &last = rewriter.getBlock()->back();
1437 if (llvm::isa<scf::YieldOp>(last)) {
1438 // Because `scf.for` inserts an implicit yield op when there is no
1439 // reduction variable upon creation, we reset the insertion point such
1440 // that the block is inlined right before the yield op.
1441 rewriter.setInsertionPoint(&last);
1442 }
1443
1444 rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1445 rewriter.getInsertionPoint(), args);
1446 rewriter.setInsertionPointToEnd(rewriter.getBlock());
1447 for (Level l = 0; l < lvlRank; l++) {
1448 // Link the reduction chain. Note that the loop emitter updates reducValue
1449 // in place.
1450 loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1451 loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1452 }
1453
1454 // Replace the foreach operator with the value returned by the outermost
1455 // for loop.
1456 rewriter.replaceOp(op, reducValue);
1457 return success();
1458 }
1459};
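// For orientation, a rough sketch of the loop structure the ForeachRewriter
// above emits for a 1-D compressed input with a single reduction value. The
// buffer names (%positions, %coordinates, %values) are illustrative only and
// do not reflect the SSA names produced by the loop emitter:
//
//   %lo = memref.load %positions[%c0] : memref<?xindex>
//   %hi = memref.load %positions[%c1] : memref<?xindex>
//   %r  = scf.for %p = %lo to %hi step %c1 iter_args(%acc = %init) {
//     %crd = memref.load %coordinates[%p] : memref<?xindex>
//     %val = memref.load %values[%p] : memref<?xf64>
//     // ... inlined foreach body computing the updated reduction ...
//     scf.yield %new_acc
//   }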
1460
1461/// Sparse rewriting rule for the new operator.
1462struct NewRewriter : public OpRewritePattern<NewOp> {
1463 using OpRewritePattern::OpRewritePattern;
1464 LogicalResult matchAndRewrite(NewOp op,
1465 PatternRewriter &rewriter) const override {
1466 Location loc = op.getLoc();
1467 auto stt = getSparseTensorType(op.getResult());
1468 if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1469 return failure();
1470
1471 // Implement the NewOp as follows:
1472 // %orderedCoo = sparse_tensor.new %filename
1473 // %t = sparse_tensor.convert %orderedCoo
1474 // with enveloping reinterpret_map ops for non-permutations.
1475 RankedTensorType dstTp = stt.getRankedTensorType();
1476 RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1477 Value cooTensor = NewOp::create(rewriter, loc, cooTp, op.getSource());
1478 Value convert = cooTensor;
1479 auto enc = stt.getEncoding();
1480 if (!stt.isPermutation()) { // demap coo, demap dstTp
1481 auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1482 convert = ReinterpretMapOp::create(rewriter, loc, coo, convert);
1483 dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1484 }
1485 convert = ConvertOp::create(rewriter, loc, dstTp, convert);
1486 if (!stt.isPermutation()) // remap to original enc
1487 convert = ReinterpretMapOp::create(rewriter, loc, enc, convert);
1488 rewriter.replaceOp(op, convert);
1489
1490 // Release the temporary ordered COO tensor.
1491 rewriter.setInsertionPointAfterValue(convert);
1492 DeallocTensorOp::create(rewriter, loc, cooTensor);
1493
1494 return success();
1495 }
1496};
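// For reference, a sketch of the expansion the NewRewriter above produces
// when the destination dimToLvl map is not a permutation (types abbreviated;
// for plain permutations the two reinterpret_map ops are simply omitted):
//
//   %coo   = sparse_tensor.new %filename           // ordered COO staging
//   %demap = sparse_tensor.reinterpret_map %coo    // drop dimToLvl
//   %conv  = sparse_tensor.convert %demap          // convert demapped type
//   %t     = sparse_tensor.reinterpret_map %conv   // restore original encoding
//   bufferization.dealloc_tensor %coo              // release temporary COO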
1497
1498/// Sparse rewriting rule for the out operator.
1499struct OutRewriter : public OpRewritePattern<OutOp> {
1500 using OpRewritePattern::OpRewritePattern;
1501 LogicalResult matchAndRewrite(OutOp op,
1502 PatternRewriter &rewriter) const override {
1503 Location loc = op.getLoc();
1504 // Calculate NNZ.
1505 Value src = op.getTensor();
1506 Value nnz = NumberOfEntriesOp::create(rewriter, loc, src);
1507
1508 // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1509 const auto srcTp = getSparseTensorType(src);
1510 const Dimension dimRank = srcTp.getDimRank();
1511 Type indexTp = rewriter.getIndexType();
1512 Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1513
1514 // Generate code to calculate dimension size values and store the values to
1515 // the buffer.
1516 SmallVector<Value> dims;
1517 sizesForTensor(rewriter, dims, loc, srcTp, src);
1518 for (Dimension d = 0; d < dimRank; d++) {
1519 memref::StoreOp::create(rewriter, loc, dims[d], dimSizes,
1520 constantIndex(rewriter, loc, d));
1521 }
1522
1523 // Create a sparse tensor writer and output meta data.
1524 Type opaqueTp = getOpaquePointerType(rewriter);
1525 Value writer =
1526 createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1527 {op.getDest()}, EmitCInterface::Off)
1528 .getResult(0);
1529 Value rankValue = constantIndex(rewriter, loc, dimRank);
1530 createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1531 {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1532
1533 Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1534 Type eltTp = srcTp.getElementType();
1535 SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1536 primaryTypeFunctionSuffix(eltTp)};
1537 Value value = genAllocaScalar(rewriter, loc, eltTp);
1538 ModuleOp module = op->getParentOfType<ModuleOp>();
1539
1540 // For each element in the source tensor, output the element.
1541 ForeachOp::create(
1542 rewriter, loc, src, ValueRange(),
1543 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1544 ValueRange reduc) {
1545 for (Dimension d = 0; d < dimRank; d++) {
1546 memref::StoreOp::create(rewriter, loc, dcvs[d], dimCoords,
1547 constantIndex(builder, loc, d));
1548 }
1549 memref::StoreOp::create(rewriter, loc, v, value);
1550 SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1551 FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1552 EmitCInterface::On);
1553 func::CallOp::create(builder, loc, TypeRange(), fn, operands);
1554 sparse_tensor::YieldOp::create(builder, loc);
1555 });
1556
1557 // Release the writer.
1558 createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1559 EmitCInterface::Off);
1560
1561 rewriter.eraseOp(op);
1562 return success();
1563 }
1564};
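// For reference, a sketch of the runtime call sequence the OutRewriter above
// emits; the function names are the ones passed to createFuncCall/getFunc,
// and the per-entry call is issued from the generated sparse_tensor.foreach:
//
//   %writer = call @createSparseTensorWriter(%dest)
//   call @outSparseTensorWriterMetaData(%writer, %rank, %nnz, %dimSizes)
//   // for every stored entry: store its coordinates and value, then
//   call @outSparseTensorWriterNext<suffix>(%writer, %rank, %dimCoords, %value)
//   call @delSparseTensorWriter(%writer)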
1565
1566} // namespace
1567
1568//===---------------------------------------------------------------------===//
1569// Methods that add patterns described in this file to a pattern list.
1570//===---------------------------------------------------------------------===//
1571
1572void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1573 patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1574 FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1575 GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1576 patterns.getContext());
1577}
1578
1579void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1580 bool enableRT,
1581 bool enableConvert) {
1582 patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1583 ReshapeRewriter<tensor::CollapseShapeOp>,
1584 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1585 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1586 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1587 patterns.getContext());
1588
1589 if (enableConvert)
1590 patterns.add<DirectConvertRewriter>(patterns.getContext());
1591 if (!enableRT)
1592 patterns.add<NewRewriter>(patterns.getContext());
1593}
1594
1595void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1596 // Run CrdTranslateRewriter later in the pipeline so that the operation can
1597 // be folded before lowering to affine.apply.
1598 patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1599}
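// --- Illustrative usage (not part of this file) -----------------------------
// A minimal sketch of how the populate* entry points above are typically
// driven from a rewrite pass. The wrapper function name is hypothetical, and
// the assumed header locations and greedy-driver entry point
// (applyPatternsGreedily) may differ across MLIR versions.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void runSparseRewritingExample(mlir::Operation *rootOp, bool enableRT,
                                      bool enableConvert) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  // Pre-sparsification cleanups (fusion and folding around sparse kernels).
  mlir::populatePreSparsificationRewriting(patterns);
  // Lower sparse ops (concatenate, reshape, new, out, ...) to foreach/convert.
  mlir::populateLowerSparseOpsToForeachPatterns(patterns, enableRT,
                                                enableConvert);
  // Finally lower sparse_tensor.foreach itself to SCF loops.
  mlir::populateLowerForeachToSCFPatterns(patterns);
  // Apply the collected patterns greedily until a fixed point is reached.
  (void)mlir::applyPatternsGreedily(rootOp, std::move(patterns));
}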