SparseTensorRewriting.cpp
1//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting rules that are specific to sparse tensors.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Utils/CodegenUtils.h"
14#include "Utils/LoopEmitter.h"
15
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Matchers.h"
31#include "mlir/Support/LLVM.h"
32
33using namespace mlir;
34using namespace mlir::bufferization;
35using namespace mlir::linalg;
36using namespace mlir::sparse_tensor;
37
38//===---------------------------------------------------------------------===//
39// Helper methods for the actual rewriting rules.
40//===---------------------------------------------------------------------===//
41
42// Helper to detect a sparse tensor type operand.
43static bool isSparseTensor(Value v) {
44 auto enc = getSparseTensorEncoding(v.getType());
45 return enc && !llvm::all_of(enc.getLvlTypes(),
46 [](auto lt) { return lt == LevelFormat::Dense; });
47}
48static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
49
50// Helper method to find zero/uninitialized tensor materialization.
51static bool isMaterializing(OpOperand *op, bool isZero) {
52 Value val = op->get();
53 // Check allocation, with zero alloc when required.
54 if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
55 Value copy = alloc.getCopy();
56 if (isZero)
57 return copy && isZeroIntegerOrFloat(copy);
58 return !copy;
59 }
60 // Check for empty tensor materialization.
61 if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
62 return !isZero;
63 // Last resort for zero alloc: the whole value is zero.
64 return isZero && isZeroIntegerOrFloat(val);
65}
66
67// Helper to detect sampling operation.
68static bool isSampling(GenericOp op) {
69 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
70 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
71 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
72 // Both scalar input arguments used exactly once.
73 Value s1 = op.getBlock()->getArgument(0);
74 Value s2 = op.getBlock()->getArgument(1);
75 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
76 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
77 }
78 }
79 return false;
80}
81
82// Helper to detect chain of multiplications that do not involve x.
83static bool isMulChain(Value val, Value x) {
84 if (auto arg = dyn_cast<BlockArgument>(val))
85 return arg != x;
86 if (auto *def = val.getDefiningOp()) {
87 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
88 return isMulChain(def->getOperand(0), x) &&
89 isMulChain(def->getOperand(1), x);
90 }
91 return false;
92}
93
94// Helper to detect x = x + <multiplications>.
95static bool isSumOfMul(GenericOp op) {
96 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
97 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
98 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
99 Value x = op.getBlock()->getArguments().back();
100 return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
101 (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
102 }
103 }
104 return false;
105}
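// Illustrative sketch (not taken from a test) of the kind of linalg.generic
// region body the helpers above match, where %x is the output block argument:
//
//   %0 = arith.mulf %a, %b : f64
//   %1 = arith.addf %x, %0 : f64
//   linalg.yield %1 : f64
//
// isMulChain(%0, %x) holds because the multiplication chain does not use %x,
// and isSumOfMul then succeeds because the addition consumes %x directly.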
106
107// Helper to detect direct yield of a zero value.
108static bool isZeroYield(GenericOp op) {
109 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
110 if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
111 if (arg.getOwner()->getParentOp() == op) {
112 return isZeroIntegerOrFloat(op->getOperand(arg.getArgNumber()));
113 }
114 }
115 return isZeroIntegerOrFloat(yieldOp.getOperand(0));
116}
117
118/// Populates given sizes array from type (for static sizes) and from
119/// the tensor (for dynamic sizes).
120static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
121 Location loc, ShapedType stp, Value tensor) {
122 for (const auto &d : enumerate(stp.getShape())) {
123 Value dim;
124 if (d.value() == ShapedType::kDynamic)
125 dim = tensor::DimOp::create(builder, loc, tensor, d.index());
126 else
127 dim = constantIndex(builder, loc, d.value());
128 sizes.push_back(dim);
129 }
130}
131
132static RankedTensorType getBufferType(const SparseTensorType &stt,
133 bool needTmpCOO) {
134 return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
135 : stt.getRankedTensorType();
136}
137
138/// Collects the dynamic dimension sizes for `tp` with the assumption that
139/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
140/// sizes to dynSizes.
141static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
142 SmallVectorImpl<Value> &dynSizes) {
143 for (const auto &d : enumerate(tp.getShape())) {
144 if (d.value() == ShapedType::kDynamic)
145 dynSizes.push_back(sizes[d.index()]);
146 }
147}
148
149static LogicalResult genForeachOnSparseConstant(ForeachOp op,
150 RewriterBase &rewriter,
151 SparseElementsAttr attr) {
152 auto loc = op.getLoc();
153 SmallVector<Value> reduc = op.getInitArgs();
154
155 // Foreach on constant.
156 foreachInSparseConstant(
157 rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
158 [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
159 SmallVector<Value> args;
160 args.append(cvs.begin(), cvs.end());
161 args.push_back(v);
162 args.append(reduc);
163 // Clones the foreach op to get a copy of the loop body.
164 auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
165 assert(args.size() == cloned.getBody()->getNumArguments());
166 Operation *yield = cloned.getBody()->getTerminator();
167 rewriter.inlineBlockBefore(cloned.getBody(), op, args);
168 // clean up
169 rewriter.eraseOp(cloned);
170 reduc = yield->getOperands();
171 rewriter.eraseOp(yield);
172 });
173
174 rewriter.replaceOp(op, reduc);
175 return success();
176}
177
178/// Populates the given sizes array for concatenation from types (for static
179/// sizes) and from the source tensors (for dynamic sizes).
180static void concatSizesFromInputs(OpBuilder &builder,
181 SmallVectorImpl<Value> &sizes, Location loc,
182 ShapedType dstTp, ValueRange srcs,
183 unsigned dim) {
184 auto dstShape = dstTp.getShape();
185 sizesFromSrc(builder, sizes, loc, srcs[0]);
186
187 // Sum up on the `dim` if the dimension is dynamic.
188 if (dstShape[dim] != ShapedType::kDynamic) {
189 // Faithfully take the static size.
190 sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
191 } else {
192 // Else, compute the shape dynamically.
193 for (const auto &src : srcs.drop_front()) {
194 Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
195 // Sum up all the sizes.
196 sizes[dim] = arith::AddIOp::create(builder, loc, sizes[dim], srcSz);
197 }
198 }
199}
200
201//===---------------------------------------------------------------------===//
202// The actual sparse tensor rewriting rules.
203//===---------------------------------------------------------------------===//
204
205namespace {
206
207/// TODO: move it to tensor dialect instead.
208///
209/// Fold `tensor.concat` and `tensor.extract_slice`
210///
211/// %concat = tensor.concat dim(2) %t0, %t1
212/// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
213/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
214/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
215/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
216/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
217///
218/// Becomes
219///
220/// %extracted0, %extracted1 = %t0, %t1
221struct FuseExtractSliceWithConcat
222 : public OpRewritePattern<tensor::ExtractSliceOp> {
223 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
224
225 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
226 PatternRewriter &rewriter) const override {
227 auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
228 if (!concatOp)
229 return failure();
230
231 Location loc = extractOp.getLoc();
232 int64_t dim = concatOp.getDim();
233 int64_t rank = extractOp.getResultType().getRank();
234
235 SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
236 SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
237
238 // Compute the partial sums for the slice offsets.
239 AffineExpr sum = rewriter.getAffineDimExpr(0);
240 SmallVector<AffineExpr> partialSums = {sum};
241 SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
242 for (auto [idx, input] :
243 llvm::enumerate(concatOp.getInputs().drop_back())) {
244 sum = sum + rewriter.getAffineDimExpr(idx + 1);
245 partialSums.push_back(sum);
246 offsetStrides.push_back(
247 rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
248 }
249 auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
250 partialSums, rewriter.getContext());
251 SmallVector<OpFoldResult> dimOffsets =
252 affine::makeComposedFoldedMultiResultAffineApply(
253 rewriter, loc, partialSumMap, offsetStrides);
254
255 auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
256 for (auto [l, r] : llvm::zip(lhs, rhs)) {
257 std::optional<int64_t> staticVal = getConstantIntValue(l);
258 if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
259 return false;
260 }
261 return lhs.size() == rhs.size();
262 };
263
264 for (auto [i, input, offset] :
265 llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
266 SmallVector<OpFoldResult> srcSizes =
267 tensor::getMixedSizes(rewriter, loc, input);
268 srcOffsets[dim] = offset;
269
270 SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
271 SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
272 SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
273
274 if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
275 allEqual(srcStrides, dstStrides)) {
276 Value operand = concatOp.getOperand(i);
277 if (operand.getType() == extractOp.getResultType())
278 rewriter.replaceOp(extractOp, operand);
279 break;
280 }
281 }
282
283 return success();
284 }
285};
286
287/// Rewriting rule that fuses sparse_tensor.convert into producer.
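/// As a rough sketch (encodings, shapes, and indexing maps assumed rather
/// than quoted from a test), this turns
///
///   %0 = tensor.empty() : tensor<8x8xf32>
///   %1 = linalg.generic ... outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
///   %2 = sparse_tensor.convert %1 : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
///
/// into a single linalg.generic whose cloned materialization and result
/// carry the sparse type directly, so the conversion disappears.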
288struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
289public:
290 using OpRewritePattern<ConvertOp>::OpRewritePattern;
291
292 LogicalResult matchAndRewrite(ConvertOp op,
293 PatternRewriter &rewriter) const override {
294 auto producer = op.getSource().getDefiningOp<GenericOp>();
295 if (!producer || producer.getDpsInits().size() != 1 ||
296 !isMaterializing(producer.getDpsInitOperand(0), false) ||
297 !producer.getResult(0).hasOneUse()) {
298 return failure();
299 }
300 // Clone the materialization operation, but update the result to sparse.
301 rewriter.setInsertionPoint(producer);
302 Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
303 Operation *cloned = rewriter.clone(*init);
304 cloned->getResult(0).setType(op.getResult().getType());
305
306 rewriter.modifyOpInPlace(producer, [&]() {
307 producer.getDpsInitsMutable().assign(cloned->getResults());
308 producer.getResult(0).setType(op.getResult().getType());
309 });
310
311 rewriter.replaceAllOpUsesWith(op, producer);
312 op->erase();
313
314 return success();
315 }
316};
317
318/// Rewriting rule that replaces a direct yield of zero with the initial allocation.
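/// Illustrative sketch (types and encoding assumed, not from a test):
/// a kernel of the form
///
///   %init = bufferization.alloc_tensor() : tensor<4x4xf32, #CSR>
///   %0 = linalg.generic ... outs(%init : tensor<4x4xf32, #CSR>) {
///   ^bb0(%out: f32):
///     linalg.yield %zero : f32
///   } -> tensor<4x4xf32, #CSR>
///
/// folds to just %init, since a newly materialized sparse tensor is already
/// all-zero; for static dense shapes a constant zero is used instead.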
319struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
320public:
321 using OpRewritePattern<GenericOp>::OpRewritePattern;
322
323 LogicalResult matchAndRewrite(GenericOp op,
324 PatternRewriter &rewriter) const override {
325 if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
326 !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
327 !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
328 return failure();
329 auto outputType = getRankedTensorType(op.getResult(0));
330 // Yielding zero on a newly materialized sparse tensor can be
331 // optimized directly (regardless of dynamic or static size).
332 if (getSparseTensorEncoding(outputType)) {
333 rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
334 return success();
335 }
336 // Use static zero value directly instead of materialization.
337 if (!outputType.hasStaticShape())
338 return failure();
339 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
340 rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
341 rewriter.eraseOp(def);
342 return success();
343 }
344};
345
346/// Rewriting rule that converts two kernels:
347///
348/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
349/// X(i,j) = S(i,j) * T(i,j)
350///
351/// into a single kernel, using distributive law:
352///
353/// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
354///
355/// This kind of fusion (merging two ops into one but using arithmetic
356/// equalities that may not hold for floating-point computations) would
357/// be undesirable in the dense case, since we distribute the multiplication
358/// into the reduction loop. However, for sparse sampling tensor S, such
359/// a fusion may actually reduce the asymptotic complexity of the kernel,
360/// since intermediate results may be nullified.
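/// A typical instance is a sampled dense-dense matrix multiplication
/// (SDDMM)-like computation. As a rough sketch with assumed maps and types:
///
///   %t = linalg.generic ...   // T(i,j) = SUM(k, A(i,j,k) * B(i,j,k))
///          ins(%a, %b : ...) outs(%zero_init : tensor<?x?xf32>)
///   %x = linalg.generic ...   // X(i,j) = S(i,j) * T(i,j)
///          ins(%s, %t : ...) outs(%alloc : tensor<?x?xf32, #CSR>)
///
/// After fusion, the multiplication by the sparse sampling tensor %s moves
/// inside the reduction loop, so work is skipped wherever %s has no stored
/// entry.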
361struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
362public:
363 using OpRewritePattern<GenericOp>::OpRewritePattern;
364
365 LogicalResult matchAndRewrite(GenericOp op,
366 PatternRewriter &rewriter) const override {
367 // Check consumer.
368 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
369 op.getNumResults() != 1 ||
370 op.getNumParallelLoops() != op.getNumLoops() ||
371 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
372 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
373 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
374 return failure();
375 // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
376 // operand can be sparse or dense, since the point of this rewriting rule
377 // is detecting a situation in which *more* sparsity is introduced into
378 // a computation, be it already sparse or still dense.
379 unsigned other = 0;
380 if (isSparseTensor(op.getDpsInputOperand(0)))
381 other = 1;
382 else if (!isSparseTensor(op.getDpsInputOperand(1)))
383 return failure();
384 // Check producer.
385 auto prod = dyn_cast_or_null<GenericOp>(
386 op.getDpsInputOperand(other)->get().getDefiningOp());
387 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
388 !prod.getResult(0).hasOneUse())
389 return failure();
390 // Sampling consumer and sum of multiplication chain producer.
391 if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
392 !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
393 !isSampling(op) || !isSumOfMul(prod))
394 return failure();
395 // Modify operand structure of producer and consumer.
396 Location loc = prod.getLoc();
397 SmallVector<Value> inputOps = prod.getInputs();
398 SmallVector<Value> outputOps = op.getOutputs();
399 SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
400 inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
401 fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
402 // Fuse producer and consumer into a new generic op.
403 auto fusedOp = GenericOp::create(
404 rewriter, loc, op.getResult(0).getType(), inputOps, outputOps,
405 rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
406 /*doc=*/nullptr, /*library_call=*/nullptr);
407 Block &prodBlock = prod.getRegion().front();
408 Block &consBlock = op.getRegion().front();
409 IRMapping mapper;
410 Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
411 unsigned num = prodBlock.getNumArguments();
412 for (unsigned i = 0; i < num - 1; i++)
413 addArg(mapper, fusedBlock, prodBlock.getArgument(i));
414 addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
415 addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
416 // Clone bodies of the producer and consumer in new evaluation order.
417 auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
418 auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
419 Value last;
420 for (auto &op : prodBlock.without_terminator())
421 if (&op != acc) {
422 last = op.getResult(0);
423 rewriter.clone(op, mapper);
424 }
425 mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
426 mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
427 last = rewriter.clone(*acc, mapper)->getResult(0);
428 linalg::YieldOp::create(rewriter, loc, last);
429 // Force initial value on merged allocation for dense outputs.
430 // TODO: deal with non alloc tensor here one day
431 if (!getSparseTensorEncoding(op.getResult(0).getType())) {
432 Value init = prod.getDpsInitOperand(0)
433 ->get()
434 .getDefiningOp<AllocTensorOp>()
435 .getCopy();
436 AllocTensorOp a =
437 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
438 rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
439 }
440 // Replace consumer with fused operation. Old producer
441 // and consumer ops will be removed by DCE.
442 rewriter.replaceOp(op, fusedOp->getResults());
443 return success();
444 }
445
446private:
447 // Helper to add argument and record the mapping.
448 static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
449 mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
450 }
451};
452
453// Fuse a tensor cast into its producing operation. Note that a tensor.cast
454// should really not be used to convert between sparse encodings. Since
455// the pattern currently appears as a result of some prior rewriting,
456// we make an attempt to repair very obvious cases.
457// TODO: audit the pure tensor dialect rewriting rules
458struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
459public:
460 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
461
462 LogicalResult matchAndRewrite(tensor::CastOp op,
463 PatternRewriter &rewriter) const override {
464 Type srcType = op.getSource().getType();
465 Type dstType = op.getDest().getType();
466 // A nop cast simply folds away.
467 if (srcType == dstType) {
468 rewriter.replaceOp(op, op->getResults());
469 return success();
470 }
471 // See if a sparsity changing cast can be fused into producer.
472 if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
473 if (Operation *def = op.getSource().getDefiningOp()) {
474 if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
475 rewriter.modifyOpInPlace(def, [&]() {
476 def->getResult(0).setType(op->getResultTypes()[0]);
477 });
478 rewriter.replaceOp(op, def->getResult(0));
479 return success();
480 }
481 }
482 }
483 // Repair tensor casts with at least one sparse operand into the
484 // properly supported sparse_tensor.convert.
485 if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
486 rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
487 return success();
488 }
489 // Fail otherwise.
490 return failure();
491 }
492};
493
494/// Rewrites a sequence of operations for sparse tensor selections into
495/// semi-ring operations such that they can be compiled correctly by the
496/// sparsifier. E.g., transforming the following sequence
497///
498/// %sel = arith.select %cond, %sp1, %sp2
499///
500/// to
501///
502/// %sel = binary %sp1, %sp2:
503/// both (%l, %r) {yield select %cond, %l, %r}
504/// left (%l) {yield select %cond, %l, 0}
505/// right (%r) {yield select %cond, 0, %r}
506///
507/// TODO: We require the tensor used for extracting conditions to be dense
508/// to sparsify the code. To support a sparse condition tensor, we need a
509/// tri-nary operation.
510struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
511public:
512 using OpRewritePattern<GenericOp>::OpRewritePattern;
513 LogicalResult matchAndRewrite(GenericOp op,
514 PatternRewriter &rewriter) const override {
515 // Rejects non sparse kernels.
516 if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
517 return failure();
518
519 Location loc = op.getLoc();
520 SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
521 for (Operation &inst : *op.getBody()) {
522 // Matches pattern.
523 auto matched = isRewritablePattern(op, &inst);
524 if (!matched.has_value())
525 continue;
526
527 rewriter.setInsertionPoint(&inst);
528 auto [c, t, f] = matched.value();
529 assert(t.getType() == f.getType());
530 auto selTp = t.getType();
531 auto c0 = constantZero(rewriter, loc, selTp);
532 auto binOp = sparse_tensor::BinaryOp::create(rewriter, loc, selTp, t, f);
533 // Initializes all the blocks.
534 rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
535 {t.getLoc(), f.getLoc()});
536 rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
537 rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
538
539 for (auto *r : binOp.getRegions()) {
540 Block *b = &r->front();
541 rewriter.setInsertionPointToStart(b);
542
543 IRMapping irMap;
544 // Clones the cmp operations into the region to make the binary op
545 // admissible.
546 Value newC = c;
547 if (auto *def = c.getDefiningOp())
548 newC = rewriter.clone(*def, irMap)->getResult(0);
549
550 irMap.map(c, newC);
551 if (r == &binOp.getLeftRegion()) {
552 irMap.map(t, b->getArgument(0));
553 irMap.map(f, c0);
554 } else if (r == &binOp.getRightRegion()) {
555 irMap.map(t, c0);
556 irMap.map(f, b->getArgument(0));
557 } else {
558 irMap.map(t, b->getArgument(0));
559 irMap.map(f, b->getArgument(1));
560 }
561 auto y = rewriter.clone(inst, irMap)->getResult(0);
562 sparse_tensor::YieldOp::create(rewriter, loc, y);
563 }
564
565 // We successfully rewrote an operation. We cannot do the replacement here
566 // because it would invalidate the iterator used by the current loop to
567 // traverse the instructions.
568 semiRings.emplace_back(&inst, binOp);
569 }
570
571 // Finalizes the replacement.
572 for (auto [sel, semi] : semiRings)
573 rewriter.replaceOp(sel, semi->getResults());
574
575 return success(!semiRings.empty());
576 }
577
578private:
579 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
580 isRewritablePattern(GenericOp op, Operation *v) {
581 auto sel = dyn_cast<arith::SelectOp>(v);
582 if (!sel)
583 return std::nullopt;
584
585 auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
586 auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
587 // TODO: For simplicity, we only handle cases where both true/false values
588 // are loaded directly from the input tensor. We can probably admit more cases
589 // in theory.
590 if (!tVal || !fVal)
591 return std::nullopt;
592
593 // Helper lambda to determine whether the value is loaded from a dense input
594 // or is a loop invariant.
595 auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
596 if (auto bArg = dyn_cast<BlockArgument>(v);
597 bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
598 return true;
599 // If the value is defined outside the loop, it is a loop invariant.
600 return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
601 };
602
603 // If the condition value is loaded directly from a dense tensor or is a
604 // loop invariant, we can sparsify the kernel.
605 auto cond = sel.getCondition();
606 if (isValFromDenseInputOrInvariant(cond))
607 return std::make_tuple(cond, tVal, fVal);
608
609 Value cmpL, cmpR;
610 if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
611 matchers::m_Any(&cmpR))) ||
612 matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
613 matchers::m_Any(&cmpR)))) {
614 // TODO: we can do it recursively to check whether all the leaf values are
615 // loaded from dense tensors or are loop invariants.
616 if (isValFromDenseInputOrInvariant(cmpL) ||
617 isValFromDenseInputOrInvariant(cmpR))
618 return std::make_tuple(cond, tVal, fVal);
619 }
620
621 return std::nullopt;
622 };
623};
624
625/// Rewrites a sparse reduction that would not sparsify directly (because
626/// doing so would only iterate over the stored elements, ignoring the
627/// implicit zeros) into a semi-ring. Applies to all prod/and/min/max
628/// (note that reductions like add/sub/or/xor can directly be sparsified
629/// since the implicit zeros do not contribute to the final result).
630/// Note that prod/and are still included since, even though they often
631/// are nullified in sparse data, they may still occur for special
632/// situations in which e.g. some rows in a sparse matrix are fully
633/// dense. For min/max, including the implicit zeros is a much more
634/// common situation.
635///
636/// TODO: this essentially "densifies" the operation; we want to implement
637/// this much more efficiently by performing the reduction over the
638/// stored values, and feed in the zero once if there were *any*
639/// implicit zeros as well; but for now, at least we provide
640/// the functionality.
641///
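/// Illustrative sketch (types assumed, not from a test): for a "min"
/// reduction whose region body is
///
///   %m = arith.minimumf %x, %y : f64
///   linalg.yield %m : f64
///
/// the rewrite routes the sparse input through a sparse_tensor.unary op
/// (present -> value, absent -> 0.0) and replaces the arith.minimumf with a
/// sparse_tensor.reduce whose identity is read from the init tensor, so that
/// the implicit zeros participate in the reduction.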
642struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
643public:
644 using OpRewritePattern<GenericOp>::OpRewritePattern;
645
646 LogicalResult matchAndRewrite(GenericOp op,
647 PatternRewriter &rewriter) const override {
648 // Reject non-reductions.
649 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
650 op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
651 return failure();
652 auto *inp = op.getDpsInputOperand(0);
653 auto *init = op.getDpsInitOperand(0);
654 if (!isSparseTensor(inp))
655 return failure();
656 // Look for direct x = x OP y for semi-ring ready reductions.
657 auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
658 .getOperand(0)
659 .getDefiningOp();
660 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
661 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
662 arith::MaxUIOp>(red))
663 return failure();
664 Value s0 = op.getBlock()->getArgument(0);
665 Value s1 = op.getBlock()->getArgument(1);
666 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
667 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
668 return failure();
669 // Identity.
670 Location loc = op.getLoc();
671 Value identity =
672 tensor::ExtractOp::create(rewriter, loc, init->get(), ValueRange());
673 // Unary {
674 // present -> value
675 // absent -> zero.
676 // }
677 Type rtp = s0.getType();
678 rewriter.setInsertionPointToStart(&op.getRegion().front());
679 auto semiring = sparse_tensor::UnaryOp::create(rewriter, loc, rtp, s0);
680 Block *present =
681 rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
682 rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
683 sparse_tensor::YieldOp::create(rewriter, loc, present->getArgument(0));
684 rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
685 rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
686 auto zero =
687 arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(rtp));
688 sparse_tensor::YieldOp::create(rewriter, loc, zero);
689 rewriter.setInsertionPointAfter(semiring);
690 // CustomReduce {
691 // x = x REDUC y, identity
692 // }
693 auto custom = sparse_tensor::ReduceOp::create(
694 rewriter, loc, rtp, semiring.getResult(), s1, identity);
695 Block *region =
696 rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
697 rewriter.setInsertionPointToStart(&custom.getRegion().front());
698 IRMapping irMap;
699 irMap.map(red->getOperand(0), region->getArgument(0));
700 irMap.map(red->getOperand(1), region->getArgument(1));
701 auto *cloned = rewriter.clone(*red, irMap);
702 sparse_tensor::YieldOp::create(rewriter, loc, cloned->getResult(0));
703 rewriter.setInsertionPointAfter(custom);
704 rewriter.replaceOp(red, custom.getResult());
705 return success();
706 }
707};
708
709/// Sparse rewriting rule for the print operator. This operation is mainly used
710/// for debugging and testing. As such, it lowers to the vector.print operation
711/// which only requires very lightweight runtime support.
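/// For a small CSR matrix, the generated IR prints output along these lines
/// (illustrative values only):
///
///   ---- Sparse Tensor ----
///   nse = 3
///   dim = ( 4, 4 )
///   lvl = ( 4, 4 )
///   pos[1] : ( 0, 1, 2, 3, 3 )
///   crd[1] : ( 0, 2, 3 )
///   values : ( 1, 2, 3 )
///   ----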
712struct PrintRewriter : public OpRewritePattern<PrintOp> {
713public:
714 using OpRewritePattern<PrintOp>::OpRewritePattern;
715 LogicalResult matchAndRewrite(PrintOp op,
716 PatternRewriter &rewriter) const override {
717 Location loc = op.getLoc();
718 auto tensor = op.getTensor();
719 auto stt = getSparseTensorType(tensor);
720 // Header with NSE.
721 auto nse = NumberOfEntriesOp::create(rewriter, loc, tensor);
722 vector::PrintOp::create(
723 rewriter, loc,
724 rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
725 vector::PrintOp::create(rewriter, loc, nse);
726 // Print run-time contents for dim/lvl sizes.
727 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("dim = "));
728 printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
729 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("lvl = "));
730 printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
731 // Use the "codegen" foreach loop construct to iterate over
732 // all typical sparse tensor components for printing.
733 foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
734 &stt](Type, FieldIndex,
735 SparseTensorFieldKind kind,
736 Level l, LevelType) {
737 switch (kind) {
738 case SparseTensorFieldKind::StorageSpec: {
739 break;
740 }
741 case SparseTensorFieldKind::PosMemRef: {
742 auto lvl = constantIndex(rewriter, loc, l);
743 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("pos["));
744 vector::PrintOp::create(rewriter, loc, lvl,
745 vector::PrintPunctuation::NoPunctuation);
746 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
747 auto pos = ToPositionsOp::create(rewriter, loc, tensor, l);
748 printContents(rewriter, loc, pos);
749 break;
750 }
751 case SparseTensorFieldKind::CrdMemRef: {
752 auto lvl = constantIndex(rewriter, loc, l);
753 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("crd["));
754 vector::PrintOp::create(rewriter, loc, lvl,
755 vector::PrintPunctuation::NoPunctuation);
756 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
757 Value crd = nullptr;
758 // For COO AoS storage, we want to print a single, linear view of
759 // the full coordinate storage at this level. For any other storage,
760 // we show the coordinate storage for every individual level.
761 if (stt.getAoSCOOStart() == l)
762 crd = ToCoordinatesBufferOp::create(rewriter, loc, tensor);
763 else
764 crd = ToCoordinatesOp::create(rewriter, loc, tensor, l);
765 printContents(rewriter, loc, crd);
766 break;
767 }
768 case SparseTensorFieldKind::ValMemRef: {
769 vector::PrintOp::create(rewriter, loc,
770 rewriter.getStringAttr("values : "));
771 auto val = ToValuesOp::create(rewriter, loc, tensor);
772 printContents(rewriter, loc, val);
773 break;
774 }
775 }
776 return true;
777 });
778 vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("----\n"));
779 rewriter.eraseOp(op);
780 return success();
781 }
782
783private:
784 // Helper to print contents of a single memref. For "push_back" vectors,
785 // we assume that the previous getters for pos/crd/val have added a
786 // slice-to-size view to make sure we just print the size and not the
787 // full capacity.
788 //
789 // Generates code to print (1-dim or higher):
790 // ( a0, a1, ... )
791 static void printContents(PatternRewriter &rewriter, Location loc,
792 Value vec) {
793 auto shape = cast<ShapedType>(vec.getType()).getShape();
794 SmallVector<Value> idxs;
795 printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
796 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
797 }
798
799 // Helper to the helper.
800 static void printContentsLevel(PatternRewriter &rewriter, Location loc,
801 Value vec, unsigned i, ArrayRef<int64_t> shape,
802 SmallVectorImpl<Value> &idxs) {
803 // Open bracket.
804 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
805 // Generate for loop.
806 auto zero = constantIndex(rewriter, loc, 0);
807 auto index = constantIndex(rewriter, loc, i);
808 auto size = memref::DimOp::create(rewriter, loc, vec, index);
809 auto step = constantIndex(rewriter, loc, 1);
810 auto forOp = scf::ForOp::create(rewriter, loc, zero, size, step);
811 idxs.push_back(forOp.getInductionVar());
812 rewriter.setInsertionPointToStart(forOp.getBody());
813 if (i < shape.size() - 1) {
814 // Enter deeper loop nest.
815 printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
816 } else {
817 // Actual contents printing.
818 auto val = memref::LoadOp::create(rewriter, loc, vec, idxs);
819 if (llvm::isa<ComplexType>(val.getType())) {
820 // Since the vector dialect does not support complex types in any op,
821 // we split those into (real, imag) pairs here.
822 Value real = complex::ReOp::create(rewriter, loc, val);
823 Value imag = complex::ImOp::create(rewriter, loc, val);
824 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
825 vector::PrintOp::create(rewriter, loc, real,
826 vector::PrintPunctuation::Comma);
827 vector::PrintOp::create(rewriter, loc, imag,
828 vector::PrintPunctuation::Close);
829 } else {
830 vector::PrintOp::create(rewriter, loc, val,
831 vector::PrintPunctuation::NoPunctuation);
832 }
833 // Terminating comma (except at end).
834 auto bound = arith::AddIOp::create(rewriter, loc, idxs.back(), step);
835 Value cond = arith::CmpIOp::create(rewriter, loc,
836 arith::CmpIPredicate::ne, bound, size);
837 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, cond, /*else*/ false);
838 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
839 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Comma);
840 }
841 idxs.pop_back();
842 rewriter.setInsertionPointAfter(forOp);
843 // Close bracket.
844 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
845 }
846
847 // Helper method to print run-time lvl/dim sizes.
848 static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
849 unsigned size, bool isDim) {
850 // Open bracket.
851 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
852 // Print unrolled contents (dimop requires constant value).
853 for (unsigned i = 0; i < size; i++) {
854 auto idx = constantIndex(rewriter, loc, i);
855 Value val;
856 if (isDim)
857 val = tensor::DimOp::create(rewriter, loc, tensor, idx);
858 else
859 val = LvlOp::create(rewriter, loc, tensor, idx);
860 vector::PrintOp::create(rewriter, loc, val,
861 i != size - 1
862 ? vector::PrintPunctuation::Comma
863 : vector::PrintPunctuation::NoPunctuation);
864 }
865 // Close bracket and end of line.
866 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
867 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
868 }
869};
870
871/// Sparse rewriting rule for the sparse-to-sparse tensor.reshape operator.
872struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
873public:
874 using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
875
876 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
877 PatternRewriter &rewriter) const override {
878 Location loc = op.getLoc();
879 Value srcTensor = op.getSource();
880 const auto srcTp = tryGetSparseTensorType(srcTensor);
881 const auto dstTp = tryGetSparseTensorType(op.getResult());
882 if (!srcTp || !dstTp)
883 return failure();
884
885 if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
886 !dstTp->hasStaticDimShape())
887 return failure();
888
889 SmallVector<Value> srcSizes;
890 sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
891 SmallVector<Value> dstSizes;
892 for (Dimension d : dstTp->getDimShape())
893 dstSizes.push_back(constantIndex(rewriter, loc, d));
894
895 Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
896 // Only need an unordered COO buffer if input and output are not sorted
897 // in the same way.
898 Type bufferTp = getBufferType(
899 dstTp->withoutDimToLvl(),
900 !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
901 SmallVector<Value> dynSizes;
902 Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes,
903 Value(), nnz, Attribute())
904 .getResult();
905
906 // Convert src coordinates to dst coordinates by first collapsing them to 1D
907 // and then expanding them to match the rank of the destination tensor.
908 // Implemented as follows:
909 // foreach srcCoords %srcTensor
910 // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
911 // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
912 // insert expandedCoords, %buffer
913 //
914 // followed by an optional
915 // %t = sparse_tensor.cast %tmp
916 // depending on whether the input/output are sorted in the same way.
917 const auto encSrc = srcTp->getEncoding();
918 ForeachOp foreachOp = ForeachOp::create(
919 rewriter, loc, srcTensor, buffer,
920 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
921 ValueRange reduc) {
922 const Dimension srcRank = srcTp->getDimRank();
923 SmallVector<Value> srcDcvs;
924 srcDcvs.reserve(srcRank);
925 for (Dimension d = 0; d < srcRank; d++) {
926 Level lvl = toLvl(encSrc, d);
927 srcDcvs.push_back(srcLcvs[lvl]);
928 }
929
930 Value collapseSize = constantIndex(builder, loc, 1);
931 for (Dimension d = 0; d < srcRank; d++)
932 collapseSize =
933 arith::MulIOp::create(builder, loc, collapseSize, srcSizes[d]);
934 SmallVector<Value, 1> collapsedSizes = {collapseSize};
935
936 ReassociationIndices collapseIdx;
937 for (Dimension i = 0; i < srcRank; i++)
938 collapseIdx.push_back(i);
939 SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
940 SmallVector<Value, 1> collapsedDcvs;
941 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
942 collapsedSizes, collapsedDcvs);
943
944 ReassociationIndices expandIdx;
945 for (Dimension i = 0; i < dstTp->getDimRank(); i++)
946 expandIdx.push_back(i);
947 SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
948 SmallVector<Value> dstDcvs;
949 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
950 dstSizes, dstDcvs);
951
952 auto t =
953 tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
954 sparse_tensor::YieldOp::create(builder, loc, t);
955 });
956
957 Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true);
958 if (bufferTp != *dstTp) {
959 auto dstRTT = dstTp->getRankedTensorType();
960 Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
961 DeallocTensorOp::create(rewriter, loc, t);
962 t = converted;
963 }
964 rewriter.replaceOp(op, t);
965 return success();
966 }
967};
968
969/// Sparse rewriting rule for the sparse-to-sparse expand/collapse shape operators.
970template <typename ReshapeOp>
971struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
972public:
973 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
974
975 LogicalResult matchAndRewrite(ReshapeOp op,
976 PatternRewriter &rewriter) const override {
977 Location loc = op.getLoc();
978 Value srcTensor = op.getSrc();
979 const auto srcTp = getSparseTensorType(srcTensor);
980 const auto dstTp = getSparseTensorType(op.getResult());
981 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
982 return failure();
983
984 // Generate code to represent the static dimension constants or compute
985 // the dynamic dimension values.
986 SmallVector<Value> srcSizes;
987 sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
988 SmallVector<Value> dstSizes;
989 SmallVector<Value> dstDynSizes;
990 if (dstTp.hasStaticDimShape()) {
991 for (Dimension d : dstTp.getDimShape())
992 dstSizes.push_back(constantIndex(rewriter, loc, d));
993 } else {
994 ArrayRef<Size> dstShape = dstTp.getDimShape();
995 genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
996 op.getReassociationIndices());
997 for (auto [idx, shape] : llvm::enumerate(dstShape)) {
998 if (shape == ShapedType::kDynamic)
999 dstDynSizes.push_back(dstSizes[idx]);
1000 }
1001 }
1002 Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
1003 // Only need an unordered COO buffer if input and output are not sorted
1004 // in the same way.
1005 Type bufferTp = getBufferType(
1006 dstTp.withoutDimToLvl(),
1007 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1008
1009 Value buffer =
1010 AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(),
1011 /*sizeHint=*/nnz, Attribute())
1012 .getResult();
1013
1014 // Implement the sparse2sparse reshape as follows:
1015 // foreach srcCoords %srcTensor
1016 // insert reshapeCvs(srcCoords), %buffer
1017 //
1018 // followed by an optional
1019 // %t = sparse_tensor.cast %tmp
1020 // depending on whether the input/output are sorted in the same way.
1021 const auto encSrc = srcTp.getEncoding();
1022 ForeachOp foreachOp = ForeachOp::create(
1023 rewriter, loc, srcTensor, buffer,
1024 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1025 ValueRange reduc) {
1026 const Dimension dimRank = srcTp.getDimRank();
1027 SmallVector<Value> srcDcvs;
1028 srcDcvs.reserve(dimRank);
1029 for (Dimension d = 0; d < dimRank; d++) {
1030 Level lvl = toLvl(encSrc, d);
1031 srcDcvs.push_back(srcLcvs[lvl]);
1032 }
1033 SmallVector<Value> dstDcvs;
1034 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1035 srcDcvs, dstSizes, dstDcvs);
1036 auto t =
1037 tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
1038 sparse_tensor::YieldOp::create(builder, loc, t);
1039 });
1040
1041 Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true);
1042 if (bufferTp != dstTp) {
1043 auto dstRTT = dstTp.getRankedTensorType();
1044 Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
1045 DeallocTensorOp::create(rewriter, loc, t);
1046 t = converted;
1047 }
1048 rewriter.replaceOp(op, t);
1049 return success();
1050 }
1051};
1052
1053/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1054/// operator.
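/// Illustrative sketch for the dense-to-sparse case (encoding and syntax
/// abbreviated, not quoted from a test): a reshape producing a sparse result
/// such as
///
///   %0 = tensor.expand_shape %d [[0, 1]]
///        : tensor<12xf64> into tensor<3x4xf64, #CSR>
///
/// is rewritten into a dense reshape followed by an explicit conversion:
///
///   %1 = tensor.expand_shape %d [[0, 1]]
///        : tensor<12xf64> into tensor<3x4xf64>
///   %2 = sparse_tensor.convert %1 : tensor<3x4xf64> to tensor<3x4xf64, #CSR>
///
/// The sparse-to-dense direction is handled symmetrically by converting the
/// source to a dense tensor before performing the dense reshape.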
1055template <typename ReshapeOp>
1056struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1057public:
1058 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1059
1060 LogicalResult matchAndRewrite(ReshapeOp op,
1061 PatternRewriter &rewriter) const override {
1062 Location loc = op->getLoc();
1063 auto encDst = getSparseTensorEncoding(op.getResult().getType());
1064 auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1065 // Since a pure dense expansion is very cheap (change of view), for
1066 // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1067 // conversion from the reshape operation itself.
1068 // All other cases are handled elsewhere.
1069 if (encDst && encSrc) {
1070 return failure();
1071 }
1072 if (encSrc) {
1073 auto rtp = getRankedTensorType(op.getSrc());
1074 auto denseTp =
1075 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1076 auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc());
1077 rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1078 return success();
1079 }
1080 if (encDst) {
1081 auto rtp = getRankedTensorType(op.getResult());
1082 auto denseTp =
1083 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1084 ReshapeOp reshape;
1085 if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1086 reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
1087 op.getReassociation(), op.getOutputShape(),
1088 op.getStaticOutputShape());
1089 } else {
1090 reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
1091 op.getReassociation());
1092 }
1093 Value convert = ConvertOp::create(rewriter, loc, rtp, reshape);
1094 rewriter.replaceOp(op, convert);
1095 return success();
1096 }
1097 return failure();
1098 }
1099};
1100
1101// A trivial wrapper to help generate different operations for dense/sparse
1102// tensors.
1103struct TensorLike {
1104 TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1105 ValueRange sizes) {
1106 SmallVector<Value> dynSzs;
1107 getDynamicSizes(rtt, sizes, dynSzs);
1108
1109 val = AllocTensorOp::create(builder, loc, rtt, dynSzs);
1110 if (!isSparse()) {
1111 Value c0 = constantZero(builder, loc, rtt.getElementType());
1112 val = linalg::FillOp::create(builder, loc, c0, val).getResult(0);
1113 }
1114 }
1115
1116 void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1117 val = tensor::InsertOp::create(builder, loc, v, val, crds);
1118 }
1119
1120 Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1121 if (isSparse())
1122 return LoadOp::create(builder, loc, val, true);
1123 return val;
1124 }
1125
1126 bool isSparse() const {
1127 return getSparseTensorEncoding(val.getType()) != nullptr;
1128 }
1129
1130 Value val;
1131};
1132
1133struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1134 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1135 LogicalResult matchAndRewrite(tensor::DimOp op,
1136 PatternRewriter &rewriter) const override {
1137 std::optional<int64_t> dim = op.getConstantIndex();
1138 auto stt = tryGetSparseTensorType(op.getSource());
1139 if (!dim || !stt || !stt->hasEncoding())
1140 return failure();
1141
1142 if (stt->isPermutation()) {
1143 rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1144 toLvl(stt->getEncoding(), *dim));
1145 return success();
1146 }
1147
1148 // Non-permutation dim2lvl/lvl2dim maps.
1149 // Compute as follows:
1150 // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1151 // Note that it is not the most efficient way (but a more general one) for
1152 // the lvl to dim translation, e.g., for BSR, the dimension size can be
1153 // computed simply by lvl_size * block_size.
1154 Location loc = op.getLoc();
1155 SmallVector<Value> maxLvlCrds;
1156 for (Level l = 0; l < stt->getLvlRank(); l++) {
1157 Value lvlSz = LvlOp::create(rewriter, loc, op.getSource(), l);
1158 Value maxLvlCrd = arith::SubIOp::create(
1159 rewriter, loc, lvlSz,
1160 constantOne(rewriter, loc, rewriter.getIndexType()));
1161 maxLvlCrds.push_back(maxLvlCrd);
1162 }
1163
1164 AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1165 Value maxDimCrd = affine::AffineApplyOp::create(
1166 rewriter, op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1167 maxLvlCrds);
1168
1169 Value dimSz = arith::AddIOp::create(
1170 rewriter, loc, maxDimCrd,
1171 constantOne(rewriter, loc, rewriter.getIndexType()));
1172 rewriter.replaceOp(op, dimSz);
1173 return success();
1174 }
1175};
1176
1177struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1178 using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
1179 LogicalResult matchAndRewrite(ConcatenateOp op,
1180 PatternRewriter &rewriter) const override {
1181 if (op.needsExtraSort())
1182 op.emitError("ConcatenateOp not staged");
1183
1184 const Location loc = op.getLoc();
1185 const auto dstTp = getSparseTensorType(op);
1186 const Dimension conDim = op.getDimension();
1187 SmallVector<Value> sizes;
1188 concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1189
1190 // %t = concatenate %s1, %s2, %s3 {dim = 1}
1191 // ==>
1192 // if (isSparseDst)
1193 // if (allDense)
1194 // %tmp = bufferization.alloc_tensor dstTp
1195 // else
1196 // %tmp = bufferization.alloc_tensor : unordered COO
1197 // else
1198 // %tmp = memref.alloc : dense tensor
1199 // foreach in %s1 : insert d0, d1, %tmp
1200 // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1201 // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1202
1203 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1204 Value offset = constantIndex(rewriter, loc, 0);
1205 Value iterArg = dstBuf.val;
1206
1207 ForeachOp foreachOp;
1208 for (Value input : op.getInputs()) {
1209 // Builds a foreach op for each input tensor to append new values into the
1210 // output tensor.
1211 foreachOp = ForeachOp::create(
1212 rewriter, loc, input, iterArg,
1213 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1214 ValueRange reduc) {
1215 SmallVector<Value> offDimCrd(dcvs);
1216 offDimCrd[conDim] =
1217 arith::AddIOp::create(builder, loc, offDimCrd[conDim], offset);
1218
1219 // Enters foreach, updates the SSA chain.
1220 dstBuf.val = reduc.front();
1221 if (!dstTp.isAllDense()) {
1222 Value cond = genIsNonzero(builder, loc, v);
1223 auto ifOp =
1224 scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
1225 /*else*/ true);
1226 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1227 scf::YieldOp::create(builder, loc, dstBuf.val);
1228
1229 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1230 dstBuf.insert(builder, loc, v, offDimCrd);
1231 scf::YieldOp::create(builder, loc, dstBuf.val);
1232
1233 // Exits the ifOp, update the sparse tensor SSA value.
1234 builder.setInsertionPointAfter(ifOp);
1235 dstBuf.val = ifOp.getResult(0);
1236 } else {
1237 dstBuf.insert(builder, loc, v, offDimCrd);
1238 }
1239 sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
1240 });
1241 // Accumulates the offset. Note that only static-shaped inputs are allowed
1242 // by concatenate op verifier, which saves us from computing the offset
1243 // dynamically.
1244 const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1245 assert(ShapedType::isStatic(sz));
1246 offset = arith::AddIOp::create(rewriter, loc, offset,
1247 constantIndex(rewriter, loc, sz));
1248 iterArg = foreachOp.getResult(0);
1249 dstBuf.val = iterArg;
1250 }
1251
1252 dstBuf.val = iterArg;
1253 Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1254 rewriter.replaceOp(op, ret);
1255 return success();
1256 }
1257};
1258
1259struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1260 using OpRewritePattern<ConvertOp>::OpRewritePattern;
1261 LogicalResult matchAndRewrite(ConvertOp op,
1262 PatternRewriter &rewriter) const override {
1263 if (op.needsExtraSort())
1264 return op.emitError("ConvertOp not staged.");
1265
1266 // TODO: Maybe we want a different operation for this too.
1267 auto encDst = getSparseTensorEncoding(op.getType());
1268 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1269 if (encDst && encSrc && !encSrc.isSlice() &&
1270 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1271 // Trivial tensor conversion and simple element type conversion are handled
1272 // in codegen.
1273 return failure();
1274 }
1275
1276 Location loc = op.getLoc();
1277 Value src = op.getSource();
1278
1279 SparseTensorType srcStt = getSparseTensorType(op.getSource());
1280 SparseTensorType dstStt = getSparseTensorType(op.getDest());
1281
1282 bool fromSparseConst = false;
1283 if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1284 if (isa<SparseElementsAttr>(constOp.getValue()))
1285 fromSparseConst = true;
1286
1287 const AffineMapAttr foreachOrder =
1288 (!dstStt.isIdentity() && fromSparseConst)
1289 ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1290 : nullptr;
1291
1292 bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1293
1294 SmallVector<Value> sizes;
1295 sizesFromSrc(rewriter, sizes, loc, src);
1296 ValueRange vs;
1297 TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1298
1299 auto foreachOp = ForeachOp::create(
1300 rewriter, loc, src, dstBuf.val, foreachOrder,
1301 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1302 ValueRange reduc) {
1303 // Enters the loop, update the SSA value for insertion chain.
1304 dstBuf.val = reduc.front();
1305 if (!skipZeroCheck) {
1306 Value cond = genIsNonzero(builder, loc, v);
1307 auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
1308 /*else*/ true);
1309 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1310 scf::YieldOp::create(builder, loc, dstBuf.val);
1311
1312 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1313 dstBuf.insert(builder, loc, v, dcvs);
1314 scf::YieldOp::create(builder, loc, dstBuf.val);
1315
1316 // Exits the ifOp, update the sparse tensor SSA value.
1317 builder.setInsertionPointAfter(ifOp);
1318 dstBuf.val = ifOp.getResult(0);
1319 } else {
1320 dstBuf.insert(builder, loc, v, dcvs);
1321 }
1322 sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
1323 });
1324
1325 rewriter.setInsertionPointAfter(foreachOp);
1326
1327 // Exits the for loop, links the SSA chain.
1328 dstBuf.val = foreachOp.getResult(0);
1329
1330 Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1331 rewriter.replaceOp(op, ret);
1332 return success();
1333 }
1334};
1335
1336struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1337 using OpRewritePattern<CrdTranslateOp>::OpRewritePattern;
1338 LogicalResult matchAndRewrite(CrdTranslateOp op,
1339 PatternRewriter &rewriter) const override {
1340 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1341 ? op.getEncoder().getDimToLvl()
1342 : op.getEncoder().getLvlToDim();
1343
1344 SmallVector<Value> outCrds;
1345 for (AffineExpr result : map.getResults()) {
1346 // TODO: we should probably expand the affine map to IR using our own
1347 // rules, since affine.apply assumes signed values, while the coordinates
1348 // we provide must always be signless.
1349 Value trans = affine::AffineApplyOp::create(
1350 rewriter, op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1351 op.getInCrds());
1352 outCrds.push_back(trans);
1353 }
1354 rewriter.replaceOp(op, outCrds);
1355 return success();
1356 }
1357};
1358
1359/// Sparse rewriting rule for the foreach operator.
1360struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1361public:
1362 using OpRewritePattern<ForeachOp>::OpRewritePattern;
1363
1364 LogicalResult matchAndRewrite(ForeachOp op,
1365 PatternRewriter &rewriter) const override {
1366
1367 auto loc = op.getLoc();
1368 Value input = op.getTensor();
1369 SmallVector<Value> reduc = op.getInitArgs();
1370 const auto stt = getSparseTensorType(input);
1371 const Level lvlRank = stt.getLvlRank();
1372
1373 // Special-case: foreach over a sparse constant uses its own rewriting
1374 // rule.
1375 if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1376 if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1377 return genForeachOnSparseConstant(op, rewriter, attr);
1378 }
1379 }
1380
1381 // Otherwise, use loop emitter to generate loops.
1382 const auto enc = stt.getEncoding();
1383
1384 // 1. Generates loop for the sparse input.
1385 LoopEmitter loopEmitter(
1386 ValueRange{input},
1387 StringAttr::get(getContext(), ForeachOp::getOperationName()));
1388 loopEmitter.initializeLoopEmit(rewriter, loc);
1389 for (Level l = 0; l < lvlRank; l++) {
1390 // TODO: provide utility function for loop sequences that only contain
1391 // one for loop?
1392 const SmallVector<TensorLevel, 1> tidLvls{
1393 loopEmitter.makeTensorLevel(0, l)};
1394 loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1395 // Note that reduc will be taken care of by the loop emitter and gets updated
1396 // in place.
1397 loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1398 reduc);
1399 }
1400
1401 SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1402 if (op.getOrder()) {
1403 // TODO: Support it so that we can do direct conversion from CSR->BSR.
1404 llvm_unreachable(
1405 "Level order not yet implemented on non-constant input tensors.");
1406 }
1407
1408 Value vals = loopEmitter.getValBuffer()[0];
1409 SmallVector<Value> pos = loopEmitter.getValPosits(0);
1410 // Loads the value from sparse tensor using position-index;
1411 // loads the value from dense tensor using coords.
1412 Value val = enc ? memref::LoadOp::create(rewriter, loc, vals, pos)
1413 : memref::LoadOp::create(rewriter, loc, vals, lcvs);
1414
1415 // 2. Inline the block in the foreach operator.
1416 Block *srcBlock = op.getBody();
1417
1418 // Remap coordinates.
1419 SmallVector<Value> args =
1420 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1421
1422 // Remap value.
1423 args.push_back(val);
1424 // Remap reduction variables.
1425 args.append(reduc);
1426
1427 // Remove sparse_tensor.yield.
1428 SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1429 rewriter.eraseOp(srcBlock->getTerminator());
1430
1431 Operation &last = rewriter.getBlock()->back();
1432 if (llvm::isa<scf::YieldOp>(last)) {
1433 // Because `scf.for` inserts an implicit yield op when there is no
1434 // reduction variable upon creation, we reset the insertion point such
1435 // that the block is inlined right before the yield op.
1436 rewriter.setInsertionPoint(&last);
1437 }
1438
1439 rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1440 rewriter.getInsertionPoint(), args);
1441 rewriter.setInsertionPointToEnd(rewriter.getBlock());
1442 for (Level l = 0; l < lvlRank; l++) {
1443 // Link the reduction chain. Note that the loop emitter updates reducValue
1444 // in place.
1445 loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1446 loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1447 }
1448
1449 // Replace the foreach operator with the value returned by the outermost
1450 // for loop.
1451 rewriter.replaceOp(op, reducValue);
1452 return success();
1453 }
1454};
1455
1456/// Sparse rewriting rule for the new operator.
1457struct NewRewriter : public OpRewritePattern<NewOp> {
1458 using OpRewritePattern::OpRewritePattern;
1459 LogicalResult matchAndRewrite(NewOp op,
1460 PatternRewriter &rewriter) const override {
1461 Location loc = op.getLoc();
1462 auto stt = getSparseTensorType(op.getResult());
1463 if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1464 return failure();
1465
1466 // Implement the NewOp as follows:
1467 // %orderedCoo = sparse_tensor.new %filename
1468 // %t = sparse_tensor.convert %orderedCoo
1469 // with enveloping reinterpret_map ops for non-permutations.
1470 RankedTensorType dstTp = stt.getRankedTensorType();
1471 RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1472 Value cooTensor = NewOp::create(rewriter, loc, cooTp, op.getSource());
1473 Value convert = cooTensor;
1474 auto enc = stt.getEncoding();
1475 if (!stt.isPermutation()) { // demap coo, demap dstTp
1476 auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1477 convert = ReinterpretMapOp::create(rewriter, loc, coo, convert);
1478 dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1479 }
1480 convert = ConvertOp::create(rewriter, loc, dstTp, convert);
1481 if (!stt.isPermutation()) // remap to original enc
1482 convert = ReinterpretMapOp::create(rewriter, loc, enc, convert);
1483 rewriter.replaceOp(op, convert);
1484
1485 // Release the temporary ordered COO tensor.
1486 rewriter.setInsertionPointAfterValue(convert);
1487 DeallocTensorOp::create(rewriter, loc, cooTensor);
1488
1489 return success();
1490 }
1491};
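// Note: a minimal sketch (hypothetical IR) of the rewrite above for a
// non-permutation dimToLvl mapping (e.g. a block-sparse encoding): the file is
// first read into an ordered COO tensor, which is demapped, converted, and
// remapped back to the requested encoding; the temporary COO tensor is then
// released:
//
//   %coo = sparse_tensor.new %filename : !llvm.ptr to tensor<?x?xf64, #COO>
//   %demap = sparse_tensor.reinterpret_map %coo ...
//   %conv = sparse_tensor.convert %demap ...
//   %result = sparse_tensor.reinterpret_map %conv ...  // original encoding
//   sparse_tensor.dealloc_tensor %coo : tensor<?x?xf64, #COO>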
1492
1493/// Sparse rewriting rule for the out operator.
1494struct OutRewriter : public OpRewritePattern<OutOp> {
1495 using OpRewritePattern::OpRewritePattern;
1496 LogicalResult matchAndRewrite(OutOp op,
1497 PatternRewriter &rewriter) const override {
1498 Location loc = op.getLoc();
1499 // Calculate NNZ.
1500 Value src = op.getTensor();
1501 Value nnz = NumberOfEntriesOp::create(rewriter, loc, src);
1502
1503 // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1504 const auto srcTp = getSparseTensorType(src);
1505 const Dimension dimRank = srcTp.getDimRank();
1506 Type indexTp = rewriter.getIndexType();
1507 Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1508
1509 // Generate code to compute the dimension sizes and store them in the
1510 // buffer.
1511 SmallVector<Value> dims;
1512 sizesForTensor(rewriter, dims, loc, srcTp, src);
1513 for (Dimension d = 0; d < dimRank; d++) {
1514 memref::StoreOp::create(rewriter, loc, dims[d], dimSizes,
1515 constantIndex(rewriter, loc, d));
1516 }
1517
1518 // Create a sparse tensor writer and output the metadata.
1519 Type opaqueTp = getOpaquePointerType(rewriter);
1520 Value writer =
1521 createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1522 {op.getDest()}, EmitCInterface::Off)
1523 .getResult(0);
1524 Value rankValue = constantIndex(rewriter, loc, dimRank);
1525 createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1526 {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1527
1528 Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1529 Type eltTp = srcTp.getElementType();
1530 SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1531 primaryTypeFunctionSuffix(eltTp)};
1532 Value value = genAllocaScalar(rewriter, loc, eltTp);
1533 ModuleOp module = op->getParentOfType<ModuleOp>();
1534
1535 // For each element in the source tensor, output the element.
1536 ForeachOp::create(
1537 rewriter, loc, src, ValueRange(),
1538 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1539 ValueRange reduc) {
1540 for (Dimension d = 0; d < dimRank; d++) {
1541 memref::StoreOp::create(rewriter, loc, dcvs[d], dimCoords,
1542 constantIndex(builder, loc, d));
1543 }
1544 memref::StoreOp::create(rewriter, loc, v, value);
1545 SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1546 FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1547 EmitCInterface::On);
1548 func::CallOp::create(builder, loc, TypeRange(), fn, operands);
1549 sparse_tensor::YieldOp::create(builder, loc);
1550 });
1551
1552 // Release the writer.
1553 createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1554 EmitCInterface::Off);
1555
1556 rewriter.eraseOp(op);
1557 return success();
1558 }
1559};
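// Note: a minimal sketch (hypothetical IR, assuming f64 elements so that the
// per-entry call resolves to outSparseTensorWriterNextF64) of the runtime call
// sequence emitted by the rewriter above:
//
//   %writer = call @createSparseTensorWriter(%dest)
//   call @outSparseTensorWriterMetaData(%writer, %rank, %nnz, %dimSizes)
//   sparse_tensor.foreach in %src ... {
//     // store the coordinates/value into the stack buffers, then:
//     call @outSparseTensorWriterNextF64(%writer, %rank, %dimCoords, %value)
//   }
//   call @delSparseTensorWriter(%writer)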
1560
1561} // namespace
1562
1563//===---------------------------------------------------------------------===//
1564// Methods that add patterns described in this file to a pattern list.
1565//===---------------------------------------------------------------------===//
1566
1567void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1568 patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1569 FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1570 GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1571 patterns.getContext());
1572}
1573
1574void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1575 bool enableRT,
1576 bool enableConvert) {
1577 patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1578 ReshapeRewriter<tensor::CollapseShapeOp>,
1579 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1580 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1581 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1582 patterns.getContext());
1583
1584 if (enableConvert)
1585 patterns.add<DirectConvertRewriter>(patterns.getContext());
1586 if (!enableRT)
1587 patterns.add<NewRewriter>(patterns.getContext());
1588}
1589
1590void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1591 // Run CrdTranslateRewriter later in the pipeline so that the operation can
1592 // be folded before lowering to affine.apply.
1593 patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1594}
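// Note: a minimal sketch (hypothetical pass, not part of this file) of how the
// populate functions above are typically driven with the greedy pattern rewrite
// driver; recent MLIR exposes applyPatternsGreedily (older releases use
// applyPatternsAndFoldGreedily):
//
//   void ExampleRewritePass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populatePreSparsificationRewriting(patterns);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }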