Sparsification.cpp
1//===- Sparsification.cpp - Implementation of sparsification --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements converting sparse tensor types to actual sparse code.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Utils/CodegenEnv.h"
14#include "Utils/CodegenUtils.h"
15#include "Utils/LoopEmitter.h"
16
31
32#include <optional>
33
34using namespace mlir;
35using namespace mlir::sparse_tensor;
36
37//===----------------------------------------------------------------------===//
38// Sparsifier analysis methods.
39//===----------------------------------------------------------------------===//
40
41/// Returns true iff affine expression is invariant. Sets the
42/// parameter `isCurrentLoop` when expression just became invariant.
43static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
44 switch (a.getKind()) {
45 case AffineExprKind::DimId: {
46 const LoopId i = cast<AffineDimExpr>(a).getPosition();
47 if (i + 1 == curr) {
48 isCurrentLoop = true;
49 return true; // becomes invariant at current loop
50 }
51 return i < curr; // invariant when already generated
52 }
53 case AffineExprKind::Add:
54 case AffineExprKind::Mul: {
55 auto binOp = cast<AffineBinaryOpExpr>(a);
56 return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
57 isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
58 }
59 default: {
60 assert(isa<AffineConstantExpr>(a));
61 return true;
62 }
63 }
64}
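// Illustrative note (not part of the original source): for a compound
// expression such as d0 + d1 with loop order (d0, d1), the call
// isInvariantAffine(d0 + d1, /*curr=*/2, flag) returns true and sets `flag`,
// since d1 satisfies i + 1 == curr (it just became invariant at the current
// loop), whereas with curr == 1 the call returns false because d1 is still
// in play.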
65
66/// Helper method to inspect affine expressions. Rejects cases where the
67/// same index is used more than once. Also rejects compound affine
68/// expressions in sparse dimensions.
69static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
70 LevelType lt, bool setLvlFormat = true) {
71 switch (a.getKind()) {
72 case AffineExprKind::DimId: {
73 const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
74 if (!isUndefLT(merger.getLvlType(tid, idx)))
75 return false; // used more than once
76 if (setLvlFormat)
77 merger.setLevelAndType(tid, idx, lvl, lt);
78 return true;
79 }
80 case AffineExprKind::Add:
81 case AffineExprKind::Mul:
82 case AffineExprKind::Constant: {
83 assert(lt.hasDenseSemantic());
84 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
85 // We do not set the dim level format for an affine expression like
86 // d0 + d1 on either loop index d0 or d1. We continue the recursion
87 // merely to check whether the current affine expression is admissible.
88 return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
89 findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
90 }
91 // Falls through when it is a constant affine expression.
92 return true;
93 }
94 default:
95 return false;
96 }
97}
98
99/// Helper method to inspect affine expressions for index variable reduction
100/// based codegen. It finds the dependent index set for all tensor levels in the
101/// current expression we are generating.
102///
103/// For example, when handling A[i+j][j+k], we build the two way mapping in
104/// merger between (tensor, level) pairs and their dependent index variable set:
105/// A_0 <=> [i, j] and A_1 <=> [j, k]
106///
107/// It rejects cases (returns false)
108/// 1st, when the same index is used more than once, e.g., A[i+j][i]
109/// 2nd, when multiplication is used in the non-trivial index expression.
110/// 3rd, when a constant operand is used in the non-trivial index expression.
111///
112/// TODO: constant should be easy to handle.
113static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
114 AffineExpr a, LevelType lt, bool isSubExp = false,
115 int64_t coefficient = 1) {
116 switch (a.getKind()) {
117 case AffineExprKind::DimId: {
118 // Only allow positive coefficients on AffineDimExpr.
119 if (coefficient <= 0)
120 return false;
121
122 const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
123 if (!isUndefLT(merger.getLvlType(tensor, idx)))
124 return false; // used more than once, e.g., A[i][i]
125
126 // TODO: Generalize the following two cases. A[i] (with a trivial index
127 // expression) can be treated as a special affine index expression. We do
128 // not necessarily need to differentiate them.
129 if (!isSubExp) {
130 assert(coefficient == 1);
131 merger.setLevelAndType(tensor, idx, lvl, lt);
132 }
133
134 if (isSubExp) {
135 // The current loop appears in more than one affine expression on the
136 // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
137 // is used twice.
138 if (merger.hasDependentLvl(idx, tensor)) {
139 // TODO: This can be supported by coiterating slices if the loop idx
140 // appears in affine indices of different tensors, or by taking a slice
141 // on multiple dimensions when it is on the same tensor.
142 // E.g.,
143 // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
144 // d0_1 = getNextSliceOffset t0 along lvl0
145 // d0_2 = getNextSliceOffset t1 along lvl0
146 // if d0_1 == d0_2 then d0 = d0_1 = d0_2
147 // else increase min(d0_1, d0_2).
148 return false;
149 }
150 merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
151 }
152 return true;
153 }
154 case AffineExprKind::Constant:
155 case AffineExprKind::Mul: {
156 // TODO: Support index expressions like `2 * d0`; we currently only support
157 // more complicated cases like `2 * d0 + d1`.
158 if (!isSubExp)
159 return false;
160
161 // TODO: Support Constant AffineExp for slice-based codegen
162 if (isa<AffineConstantExpr>(a))
163 llvm_unreachable("Not yet implemented");
164
165 auto binOp = cast<AffineBinaryOpExpr>(a);
166 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
167 if (isa<AffineConstantExpr>(rhs))
168 std::swap(lhs, rhs);
169 // Must be in form of `constant * d`.
170 assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
171 int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
172 return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
173 }
174 case AffineExprKind::Add: {
175 auto binOp = cast<AffineBinaryOpExpr>(a);
176 return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
177 findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
178 }
179 default:
180 return false;
181 }
182}
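// Illustrative note (not part of the original source): for an access such as
// B[2 * i + j], the Add case above recurses into both operands and the Mul
// case extracts the constant, so findDepIdxSet records the dependent set of
// that level as {(i, coefficient 2), (j, coefficient 1)} via
// setLoopDependentTensorLevel, while a trivial access B[i] simply registers
// the level type through setLevelAndType.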
183
184/// Gets the total number of compound affine expressions in the
185/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
186///
187/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
188///
189/// Returns 1 (because the first level is compressed and its corresponding
190/// indexing-expression is `d0 + d1`)
191static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
192 Value tensor) {
193 // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
194 // we can't use `getRankedTensorType`/`getSparseTensorType` here.
195 // However, we don't need to handle `StorageSpecifierType`, so we
196 // can use `SparseTensorType` once we guard against non-tensors.
197 const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
198 if (!rtp)
199 return 0;
200 const SparseTensorType stt(rtp);
201
202 const Level lvlRank = stt.getLvlRank();
203 const auto exprs = map.getResults();
204 assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
205 "AffineMap does not have dimension-rank many results");
206 unsigned num = 0;
207 for (Level l = 0; l < lvlRank; l++) {
208 if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
209 num++;
210 }
211 return num;
212}
213
214/// Gets the total number of sparse levels with compound affine
215/// expressions, summed over all operands of the `GenericOp`.
216static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
217 unsigned num = 0;
218 for (OpOperand &t : op->getOpOperands())
219 num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
220 t.get());
221 return num;
222}
223
224// Returns true iff output has nontrivial affine indices.
225static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
226 OpOperand *out = op.getDpsInitOperand(0);
227 if (getSparseTensorType(out->get()).isAllDense())
228 return false;
229 return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
230 out->get());
231}
232
233/// Helper method to inspect sparse encodings in the tensor types.
234/// Fills the per-dimension sparsity information for all tensors.
235/// Returns true if the sparse annotations and affine subscript
236/// expressions of all tensors are admissible. Returns false if
237/// no annotations are found or inadmissible constructs occur.
238/// We currently support two different ways to handle non-trivial index
239/// expressions on sparse tensors, and they accept different affine expressions.
240/// When using the dependent index reduction-based approach, it currently only
241/// supports affine addition index expressions.
242static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
243 bool annotated = false;
244 for (OpOperand &t : env.op()->getOpOperands()) {
245 const TensorId tid = env.makeTensorId(t.getOperandNumber());
246 const auto map = env.op().getMatchingIndexingMap(&t);
247 const auto enc = getSparseTensorEncoding(t.get().getType());
248 if (enc)
249 annotated = true;
250 const Level lvlRank = map.getNumResults();
251 assert(!enc || lvlRank == enc.getLvlRank());
252 assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
253 // We only need to do index reduction if there is at least one
254 // non-trivial index expression on sparse levels. If all non-trivial
255 // index expressions are on dense levels, we can efficiently rely on
256 // random access to locate the element.
257 bool needIdxReduc =
258 enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
259 // If the current tensor being inspected requires an affine index, it
260 // needs to be sliced.
261 for (Level l = 0; l < lvlRank; l++) {
262 const AffineExpr a = map.getResult(l);
263 const LevelType lt = enc.getLvlType(l);
264 if (idxReducBased && needIdxReduc) {
265 if (!findDepIdxSet(env.merger(), tid, l, a, lt))
266 return false; // inadmissible affine expression
267 } else {
268 if (!findAffine(env.merger(), tid, l, a, lt))
269 return false; // inadmissible affine expression
270 }
271 }
272 }
273 return annotated;
274}
275
276//===----------------------------------------------------------------------===//
277// Sparsifier synthesis methods (statements and expressions).
278//===----------------------------------------------------------------------===//
279
280/// Local bufferization of all dense and sparse data structures.
281static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
282 linalg::GenericOp op = env.op();
283 Location loc = op.getLoc();
284 assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
285
286 SmallVector<Range, 4> loopRange =
287 llvm::cast<linalg::LinalgOp>(op.getOperation())
288 .createLoopRanges(builder, loc);
289
290 env.emitter().initializeLoopEmit(
291 builder, loc,
292 /// Generates buffer for the output tensor.
293 /// Note that all sparse kernels assume that when all elements are written
294 /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
295 /// to all zeroes and only nonzero values are computed and written out.
296 /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
297 /// for the updates and no assumption on the original contents of the
298 /// output buffer is necessary.
299 [&op](OpBuilder &builder, Location loc, Value memref,
300 Value tensor) -> Value {
301 // Must not be a sparse tensor.
302 assert(!getSparseTensorEncoding(tensor.getType()));
303 // Two output tensor references should point to the same object.
304 OpOperand *lhs = op.getDpsInitOperand(0);
305 assert(lhs->get() == tensor);
306 // An output tensor can simply materialize from the buffer of the tensor
307 // that appears in the outs() clause. For updates, this has the
308 // advantage that only the nonzero values are involved in the
309 // computation, keeping the operation O(nnz). In all other cases, we are
310 // forced to zero out the buffer to enforce the assumption above, which
311 // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
312 // O(nnz) for matrices).
313 // TODO: use better analysis to avoid zeroing out the buffer?
314 bool isInit = op.isInitTensor(lhs);
315 Value init = memref;
316 if (!isInit) {
317 Value zero = constantZero(builder, loc,
318 getElementTypeOrSelf(tensor.getType()));
319 linalg::FillOp::create(builder, loc, ValueRange{zero},
320 ValueRange{init});
321 }
322 return init;
323 },
324 [&loopRange](OpBuilder &b, Location loc, Level l) {
325 assert(l < loopRange.size());
326 return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
327 });
328}
329
330/// Generates index for load/store on sparse tensor.
331static Value genIndex(CodegenEnv &env, OpOperand *t) {
332 const auto map = env.op().getMatchingIndexingMap(t);
333 const auto stt = getSparseTensorType(t->get());
334 const Level lvlRank = stt.getLvlRank();
335 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
336 const AffineExpr a = map.getResult(lvlRank - 1);
337 assert(a.getKind() == AffineExprKind::DimId);
338 const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
339 return env.getLoopVar(idx);
340}
341
342/// Generates subscript for load/store on a dense or sparse tensor.
343static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
344 SmallVectorImpl<Value> &args) {
345 const Location loc = env.op().getLoc();
346 const TensorId tid = env.makeTensorId(t->getOperandNumber());
347 const auto map = env.op().getMatchingIndexingMap(t);
348 const auto stt = getSparseTensorType(t->get());
349 if (stt.hasEncoding()) {
350 // For sparse tensors we only push the last-level's position onto `args`.
351 const auto pos = env.emitter().getValPosits(tid);
352 assert(!pos.empty());
353 args.append(pos);
354 // Simply returns the tensor to extract value using iterators.
355 if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
356 return t->get();
357 } else {
358 // For dense tensors we push all level's coordinates onto `args`.
359 const Level lvlRank = stt.getLvlRank();
360 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
361 for (Level l = 0; l < lvlRank; l++) {
362 const auto lvlExpr = map.getResult(l);
363 const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
364 args.push_back(lvlCrd);
365 }
366 }
367 return env.emitter().getValBuffer()[tid];
368}
369
370/// Generates insertion code to implement dynamic tensor load.
371static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
372 OpOperand *t) {
373 linalg::GenericOp op = env.op();
374 Location loc = op.getLoc();
375 // Direct lexicographic coordinate order, tensor loads as zero.
376 if (!env.isExpand()) {
377 Type tp = getElementTypeOrSelf(t->get().getType());
378 return constantZero(builder, loc, tp);
379 }
380 // Load from expanded access pattern.
381 Value index = genIndex(env, t);
382 return memref::LoadOp::create(builder, loc, env.getExpandValues(), index);
383}
384
385/// Generates insertion code to implement dynamic tensor load for reduction.
386static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
387 OpOperand *t) {
388 linalg::GenericOp op = env.op();
389 Location loc = op.getLoc();
390 Value identity = env.getCustomRedId();
391 // Direct lexicographic coordinate order, tensor loads as identity.
392 if (!env.isExpand())
393 return identity;
394 // Load from expanded access pattern if filled, identity otherwise.
395 Value values = env.getExpandValues();
396 Value filled = env.getExpandFilled();
397 Value index = genIndex(env, t);
398 Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
399 Value valAtIndex = memref::LoadOp::create(builder, loc, values, index);
400 return arith::SelectOp::create(builder, loc, isFilled, valAtIndex, identity);
401}
402
403static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
404 Value sparseOut, ValueRange ivs, Value v) {
405 scf::IfOp condInsert =
406 scf::IfOp::create(builder, loc, sparseOut.getType(), cond, true);
407 // True branch.
408 builder.setInsertionPointToStart(condInsert.thenBlock());
409 Value res = tensor::InsertOp::create(builder, loc, v, sparseOut, ivs);
410 scf::YieldOp::create(builder, loc, res);
411 // False branch.
412 builder.setInsertionPointToStart(condInsert.elseBlock());
413 scf::YieldOp::create(builder, loc, sparseOut);
414 // Value assignment.
415 builder.setInsertionPointAfter(condInsert);
416 return condInsert.getResult(0);
417}
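// Illustrative sketch (not part of the original source) of the IR that
// genConditionalInsert builds; the tensor type is a placeholder:
//   %res = scf.if %cond -> (tensor<...>) {
//     %ins = tensor.insert %v into %sparseOut[%ivs] : tensor<...>
//     scf.yield %ins : tensor<...>
//   } else {
//     scf.yield %sparseOut : tensor<...>
//   }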
418
419/// Generates insertion code to implement dynamic tensor store.
420static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
421 Value rhs) {
422 linalg::GenericOp op = env.op();
423 Location loc = op.getLoc();
424 // Direct insertion in lexicographic coordinate order.
425 if (!env.isExpand()) {
426 const LoopId numLoops = op.getRank(t);
427 // Retrieves the first `numLoop` induction variables.
428 SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
429 env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
430 Value chain = env.getInsertionChain();
431 if (env.isValidLexInsert()) {
432 // Generates runtime check for a valid lex during reduction,
433 // to avoid inserting the identity value for empty reductions.
434 // if (validLexInsert) then
435 // insert(rhs) into chain
436 // return updated chain
437 // else
438 // return unmodified chain
439 Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
440 chain, ivs, rhs);
441 env.updateInsertionChain(out);
442 } else {
443 Value sparseOut;
444 if (!hasAnySparseType(env.op().getInputs().getTypes())) {
445 // This is an all-dense -> sparse kernel, test rhs != 0 before
446 // insertion.
447 Value nz = genIsNonzero(builder, loc, rhs);
448 sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
449 } else {
450 sparseOut = tensor::InsertOp::create(builder, loc, rhs, chain, ivs);
451 }
452 // Generates regular insertion chain.
453 env.updateInsertionChain(sparseOut);
454 }
455 return;
456 }
457 // Generates insertion code along expanded access pattern.
458 // if (!expFilled[i]) then
459 // expFilled[i] = true
460 // expAdded[inserts++] = i
461 // endif
462 // values[i] = rhs
463 Value values = env.getExpandValues();
464 Value filled = env.getExpandFilled();
465 Value added = env.getExpandAdded();
466 Value count = env.getExpandCount();
467 Value index = genIndex(env, t);
468 Value fval = constantI1(builder, loc, false);
469 Value tval = constantI1(builder, loc, true);
470 // If statement.
471 Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
472 Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
473 isFilled, fval);
474 scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIndexType(), cond,
475 /*else=*/true);
476 // True branch.
477 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
478 memref::StoreOp::create(builder, loc, tval, filled, index);
479 memref::StoreOp::create(builder, loc, index, added, count);
480 Value one = constantIndex(builder, loc, 1);
481 Value add = arith::AddIOp::create(builder, loc, count, one);
482 scf::YieldOp::create(builder, loc, add);
483 // False branch.
484 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
485 scf::YieldOp::create(builder, loc, count);
486 builder.setInsertionPointAfter(ifOp);
487 // Value assignment.
488 env.updateExpandCount(ifOp.getResult(0));
489 memref::StoreOp::create(builder, loc, rhs, values, index);
490}
491
492/// Generates a load on a dense or sparse tensor.
493static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
494 // Test if the load was hoisted to a higher loop nest.
495 Value val = env.exp(exp).val;
496 if (val)
497 return val;
498 // Get tensor operand.
499 linalg::GenericOp op = env.op();
500 Location loc = op.getLoc();
501 OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
502 // Fold binary-valued tensor into explicit value.
503 const auto stt = getSparseTensorType(t->get());
504 if (auto explVal = stt.getExplicitVal())
505 return genValFromAttr(builder, loc, explVal);
506 // Load during insertion.
507 if (env.isSparseOutput(t)) {
508 if (env.isCustomReduc())
509 return genInsertionLoadReduce(env, builder, t);
510 return genInsertionLoad(env, builder, t);
511 }
512
513 // Actual load.
514 SmallVector<Value> args;
515 Value ptr = genSubscript(env, builder, t, args);
516 if (llvm::isa<TensorType>(ptr.getType())) {
517 assert(env.options().sparseEmitStrategy ==
518 SparseEmitStrategy::kSparseIterator);
519 return ExtractValOp::create(builder, loc, ptr,
520 llvm::getSingleElement(args));
521 }
522 return memref::LoadOp::create(builder, loc, ptr, args);
523}
524
525/// Generates a store on a dense or sparse tensor.
526static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
527 Value rhs) {
528 // Only unary and binary ops are allowed to return an uninitialized rhs
529 // to indicate a missing output, or otherwise a custom reduction that
530 // received no value to accumulate.
531 if (!rhs) {
532 assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
533 env.exp(exp).kind == TensorExp::Kind::kBinary ||
534 env.exp(exp).kind == TensorExp::Kind::kReduce);
535 return;
536 }
537 // Test if this is a scalarized reduction.
538 if (env.isReduc()) {
539 env.updateReduc(rhs);
540 return;
541 }
542 // Regular store.
543 linalg::GenericOp op = env.op();
544 Location loc = op.getLoc();
545 OpOperand *t = op.getDpsInitOperand(0);
546 if (!env.isSparseOutput(t)) {
547 SmallVector<Value> args;
548 Value ptr = genSubscript(env, builder, t, args);
549 memref::StoreOp::create(builder, loc, rhs, ptr, args);
550 return;
551 }
552 // Store during sparse insertion.
553 if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
554 genInsertionStore(env, builder, t, rhs);
555 return;
556 }
557 // Select operation insertion.
558 Value chain = env.getInsertionChain();
559 scf::IfOp ifOp =
560 scf::IfOp::create(builder, loc, chain.getType(), rhs, /*else=*/true);
561 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
562 // Existing value was preserved to be used here.
563 assert(env.exp(exp).val);
564 Value v0 = env.exp(exp).val;
565 genInsertionStore(env, builder, t, v0);
566 env.merger().clearExprValue(exp);
567 // Yield modified insertion chain along true branch.
568 Value mchain = env.getInsertionChain();
569 scf::YieldOp::create(builder, op.getLoc(), mchain);
570 // Yield original insertion chain along false branch.
571 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
572 scf::YieldOp::create(builder, loc, chain);
573 // Done with if statement.
574 env.updateInsertionChain(ifOp->getResult(0));
575 builder.setInsertionPointAfter(ifOp);
576}
577
578/// Generates an invariant value.
579inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
580 return env.exp(exp).val;
581}
582
583/// Semi-ring branches are simply inlined by the sparsifier. Prior
584/// analysis has verified that all computations are "local" to the inlined
585/// branch or otherwise invariantly defined outside the loop nest, with the
586/// exception of index computations, which need to be relinked to actual
587/// inlined cloned code.
588static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
589 Value e) {
590 if (auto arg = dyn_cast<BlockArgument>(e)) {
591 // Direct arguments of the original linalg op must be converted
592 // into dense tensor loads. Note that we should not encounter
593 // anything else. This needs to be verified by semi-ring ops.
594 linalg::GenericOp op = env.op();
595 if (arg.getOwner()->getParentOp() == op) {
596 const TensorId tid = env.makeTensorId(arg.getArgNumber());
597 OpOperand *t = &op->getOpOperand(tid);
598 assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
599 SmallVector<Value> args;
600 Value ptr = genSubscript(env, rewriter, t, args);
601 return memref::LoadOp::create(rewriter, op.getLoc(), ptr, args);
602 }
603 } else if (Operation *def = e.getDefiningOp()) {
604 // Handle index computation.
605 if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
606 return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
607 // When still defined in new body, recurse into operands.
608 if (def->getBlock() == block) {
609 rewriter.setInsertionPoint(def);
610 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
611 rewriter.modifyOpInPlace(def, [&]() {
612 def->setOperand(
613 i, relinkBranch(env, rewriter, block, def->getOperand(i)));
614 });
615 }
616 }
617 }
618 return e;
619}
620
621/// Recursively generates tensor expression.
622static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
623 if (e == ::mlir::sparse_tensor::detail::kInvalidId)
624 return Value();
625
626 linalg::GenericOp op = env.op();
627 Location loc = op.getLoc();
628 const TensorExp &exp = env.exp(e);
629 const auto kind = exp.kind;
630 if (kind == TensorExp::Kind::kTensor)
631 return genTensorLoad(env, rewriter, e);
632 if (kind == TensorExp::Kind::kInvariant)
633 return genInvariantValue(env, e);
634 if (kind == TensorExp::Kind::kLoopVar)
635 return env.getLoopVar(exp.loop);
636
637 if (kind == TensorExp::Kind::kReduce)
638 env.startCustomReduc(e); // enter custom
639
640 // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
641 // based on the type of the other operand.
642 Value v0, v1;
643 if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
644 env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
645 v1 = genExp(env, rewriter, exp.children.e1);
646 v0 = constantZero(rewriter, loc, v1.getType());
647 } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
648 env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
649 v0 = genExp(env, rewriter, exp.children.e0);
650 v1 = constantZero(rewriter, loc, v0.getType());
651 } else {
652 v0 = genExp(env, rewriter, exp.children.e0);
653 v1 = genExp(env, rewriter, exp.children.e1);
654 }
655
656 Value ee;
657 if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
658 // custom reduce did not receive a value
659 } else {
660 ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
661 if (ee &&
662 (kind == TensorExp::Kind::kUnary ||
663 kind == TensorExp::Kind::kBinary ||
664 kind == TensorExp::Kind::kReduce ||
665 kind == TensorExp::Kind::kSelect)) {
666 OpBuilder::InsertionGuard guard(rewriter);
667 ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
668 }
669 }
670
671 if (kind == TensorExp::Kind::kReduce)
672 env.endCustomReduc(); // exit custom
673
674 if (kind == TensorExp::Kind::kSelect)
675 env.merger().setExprValue(e, v0); // Preserve value for later use.
676
677 return ee;
678}
679
680/// Hoists loop invariant tensor loads for which indices have been exhausted.
681static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
682 LoopId curr, bool isStart) {
683 if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
684 return;
685 if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
686 // Inspect tensor indices.
687 linalg::GenericOp op = env.op();
688 OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
689 const auto map = op.getMatchingIndexingMap(&t);
690 const auto stt = getSparseTensorType(t.get());
691 const Level lvlRank = stt.getLvlRank();
692 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
693 bool isCurrentLoop = curr == 0; // for scalar tensors
694 for (Level l = 0; l < lvlRank; l++) {
695 const AffineExpr a = map.getResult(l);
696 if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
697 return; // still in play
698 }
699 // All exhausted at current level.
700 if (!isCurrentLoop)
701 return;
702 // Generate code for a scalarized reduction or invariant. Note that
703 // because custom reduction lhs may occur several times in the IR,
704 // we have a built-in safety for only initializing and wrapping-up
705 // the scalarized reduction once.
706 OpOperand *lhs = op.getDpsInitOperand(0);
707 if (lhs == &t) {
708 // Start or end a scalarized reduction.
709 if (isStart) {
710 if (env.isCustomReduc()) {
711 if (!env.isReduc())
712 env.startReduc(exp, env.getCustomRedId());
713 } else {
714 env.startReduc(exp, genTensorLoad(env, builder, exp));
715 }
716 if (env.hasSparseOutput())
717 env.startValidLexInsert(
718 constantI1(builder, env.op().getLoc(), false));
719 } else {
720 if (!env.isCustomReduc() || env.isReduc())
721 genTensorStore(env, builder, exp, env.endReduc());
722 if (env.hasSparseOutput())
723 env.endValidLexInsert();
724 }
725 } else {
726 // Start or end loop invariant hoisting of a tensor load.
727 if (isStart) {
728 env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
729 } else {
730 env.merger().clearExprValue(exp);
731 }
732 }
733 } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
734 env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
735 env.exp(exp).kind != TensorExp::Kind::kSynZero) {
736 // Traverse into the binary operations. Note that we only hoist
737 // tensor loads, since subsequent MLIR/LLVM passes know how to
738 // deal with all other kinds of derived loop invariants.
739 if (env.exp(exp).kind == TensorExp::Kind::kReduce)
740 env.startCustomReduc(exp); // enter custom
741 const ExprId e0 = env.exp(exp).children.e0;
742 const ExprId e1 = env.exp(exp).children.e1;
743 genInvariants(env, builder, e0, curr, isStart);
744 genInvariants(env, builder, e1, curr, isStart);
745 if (env.exp(exp).kind == TensorExp::Kind::kReduce)
746 env.endCustomReduc(); // exit custom
747 }
748}
749
750/// Generates an expanded access pattern in innermost dimension.
751static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
752 bool isStart) {
753 linalg::GenericOp op = env.op();
754 OpOperand *lhs = op.getDpsInitOperand(0);
755 if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
756 return; // not needed at current level
757 assert(!env.isReduc());
758 // Generate start or end of an expanded access pattern. Note that because
759 // an expansion does not rely on the ongoing contents of the sparse storage
760 // scheme, we can use the original tensor as incoming SSA value (which
761 // simplifies codegen a bit). If expansion on the actual contents is ever
762 // needed, we will need to use the SSA value in the insertion chain instead.
763 Value tensor = lhs->get();
764 Location loc = op.getLoc();
765 if (isStart) {
766 auto dynShape = {ShapedType::kDynamic};
767 Type etp = cast<ShapedType>(tensor.getType()).getElementType();
768 Type t1 = MemRefType::get(dynShape, etp);
769 Type t2 = MemRefType::get(dynShape, builder.getI1Type());
770 Type t3 = MemRefType::get(dynShape, builder.getIndexType());
771 Type t4 = builder.getIndexType();
772 auto r =
773 ExpandOp::create(builder, loc, TypeRange({t1, t2, t3, t4}), tensor);
774 assert(r.getNumResults() == 4);
775 env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
776 r.getResult(3));
777 } else {
778 SmallVector<Value> indices;
779 for (LoopId i = 0; i < curr; i++)
780 indices.push_back(env.emitter().getLoopIV(i));
781 Value values = env.getExpandValues();
782 Value filled = env.getExpandFilled();
783 Value added = env.getExpandAdded();
784 Value count = env.getExpandCount();
785 Value chain = env.getInsertionChain();
786 Value compress = CompressOp::create(builder, loc, values, filled, added,
787 count, chain, indices);
788 env.updateInsertionChain(compress);
789 env.endExpand();
790 }
791}
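// Illustrative sketch (not part of the original source) of the expansion
// protocol that genExpand sets up around the innermost loop; the assembly
// format of the sparse_tensor ops is abbreviated here:
//   %values, %filled, %added, %count = sparse_tensor.expand %tensor
//   ... innermost loop fills %values/%filled/%added and advances %count ...
//   %chain = sparse_tensor.compress %values, %filled, %added, %count
//              into %chain[%indices]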
792
793/// Returns parallelization strategy. Any implicit loop in the Linalg
794/// operation that is marked "parallel" is a candidate. Whether it is actually
795/// converted to a parallel operation depends on the requested strategy.
796static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
797 // Reject parallelization of sparse output.
798 if (env.hasSparseOutput())
799 return false;
800 // Parallel loops on tensor expansion can cause data races.
801 if (env.isExpand())
802 return false;
803 // Inspect strategy.
804 switch (env.options().parallelizationStrategy) {
805 case SparseParallelizationStrategy::kNone:
806 return false;
807 case SparseParallelizationStrategy::kDenseOuterLoop:
808 return isOuter && !isSparse;
809 case SparseParallelizationStrategy::kAnyStorageOuterLoop:
810 return isOuter;
811 case SparseParallelizationStrategy::kDenseAnyLoop:
812 return !isSparse;
813 case SparseParallelizationStrategy::kAnyStorageAnyLoop:
814 return true;
815 }
816 llvm_unreachable("unexpected parallelization strategy");
817}
818
819/// Whether or not the current loop being generated should be parallelized (if
820/// possible) according to the configuration.
821static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
822 ArrayRef<TensorLevel> tidLvls) {
823 linalg::GenericOp op = env.op();
824 auto iteratorTypes = op.getIteratorTypesArray();
825 bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
826 // Queries the LT based on the tensor and loop id, as requested by
827 // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
828 // should be consistent with the LT indexed by <TensorId, Level>.
829 const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
830 return lt.hasSparseSemantic();
831 });
832 return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
833}
834
835/// Emit a loop to coiterate over the list of tensor levels. The generated loop
836/// can either be a for loop or while loop depending on whether there is at most
837/// one sparse level in the list.
838static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
839 ArrayRef<TensorLevel> tidLvls,
840 unsigned numCases, bool tryParallel,
841 bool needsUniv) {
842 Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
843 // Construct while-loop with a parameter for each index.
844 return env.emitter().enterCoIterationOverTensorsAtLvls(
845 builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
846 needsUniv);
847 });
848 assert(loop);
849 return loop;
850}
851
852/// Generates a for-loop or a while-loop, depending on whether it implements
853/// singleton iteration or co-iteration over the given conjunction.
854static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
855 unsigned numCases, bool needsUniv,
856 ArrayRef<TensorLevel> tidLvls) {
857 bool tryParallel = shouldTryParallize(env, curr, tidLvls);
858 return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
859 needsUniv);
860}
861
862/// Generates the induction structure for a while-loop.
863static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
864 bool needsUniv) {
865 Location loc = env.op().getLoc();
866 // Finalize each else branch of all if statements.
867 if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
868 while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
869 builder.getInsertionBlock()->getParentOp())) {
870 // Break on IfOp for slicing filtering.
871 if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
872 StringAttr::get(ifOp->getContext(), "slice"))
873 break;
874
875 unsigned y = 0;
876 SmallVector<Value> yields;
877 if (env.isReduc()) {
878 yields.push_back(env.getReduc());
879 env.updateReduc(ifOp.getResult(y++));
880 if (env.isValidLexInsert()) {
881 yields.push_back(env.getValidLexInsert());
882 env.updateValidLexInsert(ifOp.getResult(y++));
883 }
884 }
885 if (env.isExpand()) {
886 yields.push_back(env.getExpandCount());
887 env.updateExpandCount(ifOp->getResult(y++));
888 }
889 if (env.getInsertionChain()) {
890 yields.push_back(env.getInsertionChain());
891 env.updateInsertionChain(ifOp->getResult(y++));
892 }
893 assert(y == yields.size());
894 scf::YieldOp::create(builder, loc, yields);
895 builder.setInsertionPointAfter(ifOp);
896 }
897 }
898 // No need to set the insertion point here as LoopEmitter keeps track of the
899 // basic block where scf::Yield should be inserted.
900}
901
902/// Generates a case region in the coiterate operation.
903static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
904 unsigned caseIdx, LatPointId allCase,
905 LatPointId curCase,
906 MutableArrayRef<Value> reduc) {
907 assert(allCase == curCase || env.merger().latGT(allCase, curCase));
908 const BitVector &allCaseBits = env.merger().lat(allCase).simple;
909 const BitVector &curCaseBits = env.merger().lat(curCase).simple;
910
911 /// Computes the subset of iterators that are valid in the current case being
912 /// generated.
913 I64BitSet caseBit(0);
914 for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
915 if (curCaseBits.test(set))
916 caseBit.set(idx);
917
918 env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
919 caseIdx, reduc);
920}
921
922/// Generates a single if-statement within a while-loop.
923static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
924 LatPointId p) {
925 Location loc = env.op().getLoc();
926 SmallVector<Type> types;
927 Value cond;
928 env.merger().foreachTensorLoopId(
929 p, /*simple=*/true,
930 [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
931 bool isIdxRed) {
932 if (isIdxRed) {
933 // Since there is no 1:1 mapping from loop to level (multiple loops
934 // are required to resolve one level with non-trivial index
935 // expression), we need to reconstruct the tensor level types if this
936 // loop requires index reduction condition.
937 assert(lvl.has_value() && isUndefLT(lt));
938 auto stt = getSparseTensorType(env.op().getInputs()[tid]);
939 lt = stt.getLvlType(*lvl);
940 }
941 assert(curr == env.merger().loop(b));
942 Value clause;
943 if (lt.hasSparseSemantic()) {
944 assert(lvl.has_value());
945 const Value crd = env.emitter().getCoord(tid, *lvl);
946 const Value lvar = env.getLoopVar(curr);
947 clause = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
948 crd, lvar);
949 } else {
950 assert(lt.hasDenseSemantic() || isUndefLT(lt));
951 clause = constantI1(builder, loc, true);
952 }
953 cond =
954 cond ? arith::AndIOp::create(builder, loc, cond, clause) : clause;
955 });
956 if (env.isReduc()) {
957 types.push_back(env.getReduc().getType());
958 if (env.isValidLexInsert())
959 types.push_back(env.getValidLexInsert().getType());
960 }
961 if (env.isExpand())
962 types.push_back(builder.getIndexType());
963 if (env.getInsertionChain())
964 types.push_back(env.getInsertionChain().getType());
965 scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true);
966 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
967 return ifOp;
968}
969
970/// Generates end of true branch of if-statement within a while-loop.
971static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
972 Value redInput, Value cntInput, Value insInput,
973 Value validIns) {
974 SmallVector<Value> operands;
975 if (env.isReduc()) {
976 operands.push_back(env.getReduc());
977 env.updateReduc(redInput);
978 if (env.isValidLexInsert()) {
979 // Any overlapping indices during a reduction creates a valid lex insert.
980 operands.push_back(constantI1(builder, env.op().getLoc(), true));
981 env.updateValidLexInsert(validIns);
982 }
983 }
984 if (env.isExpand()) {
985 operands.push_back(env.getExpandCount());
986 env.updateExpandCount(cntInput);
987 }
988 if (env.getInsertionChain()) {
989 operands.push_back(env.getInsertionChain());
990 env.updateInsertionChain(insInput);
991 }
992 if (!operands.empty())
993 scf::YieldOp::create(builder, env.op().getLoc(), operands);
994 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
995}
996
997//===----------------------------------------------------------------------===//
998// Sparsifier synthesis methods (loop sequence).
999//===----------------------------------------------------------------------===//
1000
1001static bool getAllTidLvlsInLatPoints(
1002 CodegenEnv &env, LatPointId li, LoopId curr,
1003 llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1004 const BitVector &simple = env.lat(li).simple;
1005 const TensorId outTid = env.merger().getOutTensorID();
1006 const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1007
1008 unsigned numloopCond = 0;
1009 bool hasNonUnique = false;
1010 env.merger().foreachTensorLoopId(
1011 li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1012 LevelType lt, bool isIdxReduc) {
1013 if (simple[b]) {
1014 if (isIdxReduc) {
1015 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1016 numloopCond++;
1017 return;
1018 }
1019 if (isUndefLT(lt)) {
1020 // An undefined lt in the lattices means we probably intend to
1021 // generate a dense loop according to the synthetic tensor (for
1022 // invariants and sparse output tensor).
1023 if (env.merger().getSynTensorID() == tid) {
1024 // Coiterating with an invariant
1025 // e.g., out = prod(in[i][j] op invariant);
1026 // or a broadcast
1027 // e.g., out[i][j] = in[i] (j is undef for input)
1028 //
1029 // The level of the synthetic tensor is the current loop depth;
1030 // the rank of the synthetic tensor equals the number of loops.
1031 assert(curr == env.getCurrentDepth());
1032 lvl = curr;
1033 } else if (!lvl) {
1034 // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1035 return;
1036 }
1037 }
1038 hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1039 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1040 numloopCond++;
1041 } else if (lt.hasDenseSemantic() || isIdxReduc) {
1042 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1043 } else {
1044 assert(isUndefLT(lt));
1045 linalg::GenericOp op = env.op();
1046 if (tid >= op.getNumDpsInputs())
1047 // We only handle affine expressions on input tensors (for now).
1048 return;
1049 OpOperand *operand = &op->getOpOperand(tid);
1050 const auto stt = getSparseTensorType(operand->get());
1051 // Non-annotated dense tensors require no special handling.
1052 if (!stt.hasEncoding())
1053 return;
1054
1055 ArrayRef<AffineExpr> affines =
1056 op.getMatchingIndexingMap(operand).getResults();
1057 const Level lvlRank = stt.getLvlRank();
1058 assert(affines.size() == static_cast<size_t>(lvlRank));
1059 for (Level l = 0; l < lvlRank; l++) {
1060 AffineExpr exp = affines[l];
1061 // Skip simple affine expression and non-dense levels (which
1062 // have their own filter loop).
1063 LevelType lt = stt.getLvlType(l);
1064 if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1065 continue;
1066
1067 // Constant affine expressions are handled in genLoop.
1068 if (!isa<AffineConstantExpr>(exp)) {
1069 bool isCurrentLoop = false;
1070 assert(curr == env.getCurrentDepth());
1071 if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1072 isCurrentLoop) {
1073 // If the compound affine expression is invariant and we are right
1074 // at the level, we need to generate the address according to the
1075 // affine expression. This is also the best place we can do it,
1076 // to avoid putting it inside inner loops.
1077 callback(env.makeTensorLevel(tid, l), exp);
1078 }
1079 }
1080 }
1081 }
1082 });
1083
1084 if (isDenseLT(env.lt(outTid, curr))) {
1085 auto stt = getSparseTensorType(env.op().getOutputs().front());
1086 // Note that we generate dense indices of the output tensor unconditionally,
1087 // since they may not appear in the lattice, but may be needed for
1088 // linearized env.
1089 // TODO: we should avoid introducing corner cases for all-dense sparse
1090 // tensors.
1091 if (stt.hasEncoding() && stt.isAllDense())
1092 callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1093 }
1094
1095 if (numloopCond == 0) {
1096 // Corner cases where the loop bound is defined by an *unused* operand; in
1097 // this case, we just generate a dense "fake" loop by iterating over the
1098 // synthetic tensor.
1099 callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1100 numloopCond++;
1101 }
1102 // If we just need one loop condition and the condition is not imposed on a
1103 // non-unique level, the loop can be generated by a for loop.
1104 // Or, if we are generating sparse-iterator-based loops, we always generate
1105 // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1106 return numloopCond == 1 &&
1107 (!hasNonUnique || env.options().sparseEmitStrategy ==
1108 SparseEmitStrategy::kSparseIterator);
1109}
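// Illustrative note (not part of the original source): iterating a single
// compressed operand, as in x(i) = a(i) * 2.0, yields numloopCond == 1 and
// hence a plain for loop over the stored entries, whereas coiterating two
// sparse operands, as in x(i) = a(i) * b(i), gives two loop conditions and
// therefore requires a while loop (or sparse_tensor.iterate cases).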
1110
1111/// Starts a loop sequence at given level. Returns true if
1112/// the universal loop index must be maintained at this level.
1113static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1114 LoopId curr, LatSetId lts) {
1115 assert(!env.getLoopVar(curr));
1116 // Emit invariants at this loop sequence level.
1117 genInvariants(env, builder, exp, curr, /*isStart=*/true);
1118 // Emit access pattern expansion for sparse tensor output.
1119 genExpand(env, builder, curr, /*isStart=*/true);
1120 // Emit further initialization at this loop sequence level.
1121 const LatPointId l0 = env.set(lts)[0];
1122
1123 SmallVector<TensorLevel> tidLvls;
1124 getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1125 // TODO: remove this! The same tensor level might be added multiple
1126 // times due to the special handling for the all-dense "sparse" output tensor
1127 // (see L1038).
1128 if (llvm::is_contained(tidLvls, tl))
1129 return;
1130 tidLvls.emplace_back(tl);
1131 });
1132
1133 env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1134
1135 // Maintain the universal index only if it is actually
1136 // consumed by a subsequent lattice point.
1137 for (const LatPointId li : env.set(lts).drop_front())
1138 if (!env.merger().hasAnySparse(env.lat(li).simple))
1139 return true;
1140
1141 return false;
1142}
1143
1144// Generates dense affine address for encoding.
1145static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1146 OpBuilder &builder, TensorId tid,
1147 Level startLvl) {
1148 // TODO: Handle affine expression on output tensor.
1149 linalg::GenericOp op = env.op();
1150 assert(tid < op.getNumDpsInputs());
1151 OpOperand *input = op.getDpsInputOperands()[tid];
1152 const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1153 const auto enc = getSparseTensorEncoding(input->get().getType());
1154 if (enc) {
1155 const Location loc = op.getLoc();
1156 const TensorId tid = env.makeTensorId(input->getOperandNumber());
1157 const Level lvlRank = enc.getLvlRank();
1158 assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1159 for (Level l = startLvl; l < lvlRank; l++) {
1160 AffineExpr lvlExpr = lvlExprs[l];
1161 if (enc.getLvlType(l).hasDenseSemantic() &&
1162 isa<AffineConstantExpr>(lvlExpr))
1163 env.emitter().locateLvlAtAffineAddress(
1164 builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1165 else
1166 return; // break on first non-dense non-constant level
1167 }
1168 }
1169}
1170
1171// We can generate addresses for constant affine expressions before any loops,
1172// starting from the first level, as they do not depend on anything.
1173// E.g., for [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first
1174// two levels can be determined before the loops.
1175static void genInitConstantDenseAddress(CodegenEnv &env,
1176 RewriterBase &rewriter) {
1177 for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1178 genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1179}
1180
1181/// Returns true if the lattice bit can be iterated by a for loop.
1182static bool translateBitsToTidLvlPairs(
1183 CodegenEnv &env, LatPointId li, LoopId curr,
1184 SmallVectorImpl<TensorLevel> &tidLvls,
1185 SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1186 return getAllTidLvlsInLatPoints(env, li, curr,
1187 [&](TensorLevel tl, AffineExpr exp) {
1188 if (exp)
1189 affineTidLvls.emplace_back(tl, exp);
1190 else
1191 tidLvls.emplace_back(tl);
1192 });
1193}
1194
1195/// Starts a single loop in current sequence.
1196static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1197 OpBuilder &builder, LoopId curr,
1198 LatPointId li, unsigned numCases,
1199 bool needsUniv) {
1200 // TODO: numCases is only used when generating iterator-based loops. Clean up
1201 // after the migration is complete.
1202 // The set of tensors + lvls to generate loops on.
1203 SmallVector<TensorLevel> tidLvls;
1204
1205 // The set of dense tensors with non-trivial affine expressions that just
1206 // became invariant and whose addresses are generated at the current level.
1207 SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1208 bool isSingleCond =
1209 translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1210
1211 // Emit the for/while-loop control.
1212 Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1213 Location loc = env.op().getLoc();
1214 for (auto [tidLvl, exp] : affineTidLvls) {
1215 env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1216 }
1217
1218 // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1219 // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent
1220 // on constant affine expressions may now be determined.
1221 auto allTidLvls =
1222 llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1223 for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1224 if (tid != env.merger().getOutTensorID() &&
1225 tid != env.merger().getSynTensorID())
1226 genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1227 }
1228
1229 return std::make_pair(loop, isSingleCond);
1230}
1231
1232/// Ends a single loop in current sequence. Returns the new value for needsUniv.
1233static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1234 LatPointId li, bool needsUniv, bool isSingleCond) {
1235 // Either a for-loop or a while-loop that iterates over a slice.
1236 if (isSingleCond) {
1237 // Any iteration creates a valid lex insert.
1238 if (env.isReduc() && env.isValidLexInsert())
1239 env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1240 } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1241 // End a while-loop.
1242 finalizeWhileOp(env, rewriter, needsUniv);
1243 } else {
1244 needsUniv = false;
1245 }
1246 env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1247 env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1248 return std::nullopt;
1249 });
1250 return needsUniv;
1251}
1252
1253/// Ends a loop sequence at given level.
1254static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1255 unsigned at) {
1256 assert(!env.getLoopVar(at));
1257 env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1258 // Unmark bookkeeping of invariants and loop index.
1259 genInvariants(env, builder, exp, at, /*isStart=*/false);
1260 // Finalize access pattern expansion for sparse tensor output.
1261 genExpand(env, builder, at, /*isStart=*/false);
1262}
1263
1264/// Recursively generates code while computing iteration lattices in order
1265/// to manage the complexity of implementing co-iteration over unions
1266/// and intersections of sparse iterations spaces.
1267static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1268 LoopId curr) {
1269 assert(curr == env.getCurrentDepth());
1270
1271 // At each leaf, assign remaining tensor (sub)expression to output tensor.
1272 if (curr == env.getLoopNum()) {
1273 Value rhs = genExp(env, rewriter, exp);
1274 genTensorStore(env, rewriter, exp, rhs);
1275 return;
1276 }
1277
1278 // Construct iteration lattices for current loop index.
1279 const LatSetId lts =
1280 env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1281
1282 // Start a loop sequence.
1283 bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1284
1285 // When using sparse-iterator-based loops, we only need one loop, as
1286 // opposed to a loop sequence, to cover all the iteration spaces.
1287 const unsigned lsize = env.set(lts).size();
1288 if (env.generatingSparseIterator()) {
1289 // Get the largest lattice point and start a loop.
1290 const LatPointId li = env.set(lts)[0];
1291 auto [loop, isSingleCond] =
1292 startLoop(env, rewriter, curr, li, lsize, needsUniv);
1293 assert(isSingleCond == llvm::isa<IterateOp>(loop));
1294 // We cannot change this to `for (const LatPointId li : env.set(lts))`
1295 // because the loop body causes data-movement which invalidates
1296 // the iterator.
1297 for (unsigned j = 0; j < lsize; j++) {
1298 const LatPointId lj = env.set(lts)[j];
1299 const ExprId ej = env.lat(lj).exp;
1300 // Recurse into body of each branch.
1301 if (!isSingleCond) {
1302 env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1303 genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1304 genStmt(env, rewriter, ej, curr + 1);
1305 // TODO: handle yield values.
1306 assert(reduc.empty() && "Not Implemented");
1307 sparse_tensor::YieldOp::create(rewriter, env.op().getLoc());
1308 return std::nullopt;
1309 });
1310 // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1311 } else {
1312 genStmt(env, rewriter, ej, curr + 1);
1313 }
1314 }
1315 // End a loop.
1316 needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1317 } else {
1318 // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1319 for (unsigned i = 0; i < lsize; i++) {
1320 const LatPointId li = env.set(lts)[i];
1321 // Start a loop.
1322 auto [loop, isSingleCond] =
1323 startLoop(env, rewriter, curr, li, lsize, needsUniv);
1324
1325 // Visit all lattices points with Li >= Lj to generate the
1326 // loop-body, possibly with if statements for coiteration.
1327 Value redInput = env.getReduc();
1328 Value cntInput = env.getExpandCount();
1329 Value insInput = env.getInsertionChain();
1330 Value validIns = env.getValidLexInsert();
1331 // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1332 // because the loop body causes data-movement which invalidates the
1333 // iterator.
1334 for (unsigned j = 0; j < lsize; j++) {
1335 const LatPointId lj = env.set(lts)[j];
1336 const ExprId ej = env.lat(lj).exp;
1337 if (li == lj || env.merger().latGT(li, lj)) {
1338 // Recurse into body of each branch.
1339 if (!isSingleCond) {
1340 scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1341 genStmt(env, rewriter, ej, curr + 1);
1342 endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1343 } else {
1344 genStmt(env, rewriter, ej, curr + 1);
1345 }
1346 }
1347 }
1348
1349 // End a loop.
1350 needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1351 }
1352 }
1353
1354 // End a loop sequence.
1355 endLoopSeq(env, rewriter, exp, curr);
1356 assert(curr == env.getCurrentDepth());
1357}
1358
1359/// Converts the result computed by the sparse kernel into the required form.
1360static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1361 linalg::GenericOp op = env.op();
1362 OpOperand *lhs = op.getDpsInitOperand(0);
1363 Value tensor = lhs->get();
1364 Type resType = tensor.getType();
1365 if (getSparseTensorEncoding(resType)) {
1366 // The sparse tensor rematerializes from the original sparse tensor's
1367 // underlying sparse storage format. For an insertion chain, the
1368 // tensor materializes from the chain with 'hasInserts' enabled.
1369 bool hasInserts = false;
1370 if (Value chain = env.getInsertionChain()) {
1371 hasInserts = true;
1372 tensor = chain;
1373 }
1374 rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1375 } else {
1376 // To rematerialize a non-annotated tensor, simply load it
1377 // from the bufferized value.
1378 Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1379 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1380 }
1381}
1382
1383//===----------------------------------------------------------------------===//
1384// Sparsifier rewriting methods.
1385//===----------------------------------------------------------------------===//
1386
1387namespace {
1388
1389/// Sparse rewriting rule for generic Linalg operation.
1390struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1391public:
1392 GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1393 : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1394
1395 LogicalResult matchAndRewrite(linalg::GenericOp op,
1396 PatternRewriter &rewriter) const override {
1397 // Only accept single output operations with pure tensor semantics.
1398 if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1399 return failure();
1400
1401 // Only accept trivial affine indices.
1402 if (hasNonTrivialAffineOnSparseOut(op))
1403 return failure();
1404
1405 // Only accept scheduled loops.
1406 if (!op->hasAttr("sorted")) {
1407 return rewriter.notifyMatchFailure(
1408 op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1409 "before sparsification.");
1410 }
1411
1412 // Must have been demapped as well if the generic op is sorted.
1413 assert(!hasAnyNonIdentityOperandsOrResults(op));
1414
1415 // Sets up a code generation environment.
1416 const unsigned numTensors = op->getNumOperands();
1417 const unsigned numLoops = op.getNumLoops();
1418 bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1419 // If we have an indexing map like (d0) -> (0, d0), there might be more
1420 // levels than loops because of the constant index; that means we cannot
1421 // use numLoops as the upper bound for the ranks of all tensors.
1422 // TODO: Constant indices are currently not supported on sparse tensors, but
1423 // are allowed in non-annotated dense tensors. Support them; this would also
1424 // be required for sparse tensor slice rank reduction.
1425 Level maxLvlRank = 0;
1426 for (auto operand : op.getOperands()) {
1427 if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1428 maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1429 }
1430 }
1431
1432 // Detects sparse annotations and translates the per-level sparsity
1433 // information for all tensors to loop indices in the kernel.
1434 CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1435 if (!findSparseAnnotations(env, needIdxRed))
1436 return failure();
1437
1438 // Only standard reduction operations (add, sub, or, xor) that can be
1439 // sparsified by merely reducing the stored values are admissible. More
1440 // elaborate reduction operations (such as mul, and, min, max) would need
1441 // to know whether implicit zeros occur as well. They can still be
1442 // implemented with a custom reduction operation, accepted here as well.
1443 if (op.getNumReductionLoops() > 0) {
1444 Operation *yield = op.getRegion().front().getTerminator();
1445 assert(isa<linalg::YieldOp>(yield));
1446 Operation *redop = yield->getOperand(0).getDefiningOp();
1447 if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1448 !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1449 !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1450 !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1451 !isa<ReduceOp>(redop)) {
1452 return failure();
1453 }
1454 }
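 // Worked example (not part of the original source): for a sparse vector
 // with stored values {2.0, 3.0} and one implicit zero, an add reduction
 // over the stored values yields 5.0, which matches the true sum because
 // the implicit zero contributes nothing. A mul reduction over the stored
 // values would yield 6.0, while the true product is 0.0; such reductions
 // are therefore only accepted through a custom sparse_tensor.reduce op
 // that spells out the intended semantics.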
1455
1456 // Constructs the tensor expression tree from `op`; returns failure if the
1457 // tree cannot be built or the tensor expression is inadmissible.
1458 if (failed(env.initTensorExp()))
1459 return failure();
1460
1461 // Recursively generates code if admissible.
1462 env.startEmit(options.sparseEmitStrategy);
1463 genBuffers(env, rewriter);
1464 // TODO: Constant affine expressions should be handled differently when
1465 // using slice-based codegen; it does not matter now because we already
1466 // reject constant expressions at an earlier stage.
1467 genInitConstantDenseAddress(env, rewriter);
1468 genStmt(env, rewriter, env.getExprId(), 0);
1469 genResult(env, rewriter);
1470 return success();
1471 }
1472
1473private:
1474 /// Options to control sparse code generation.
1475 SparsificationOptions options;
1476};
1477
1478} // namespace
1479
1480/// Populates the given patterns list with rewriting rules required for
1481/// the sparsification of linear algebra operations.
1482void mlir::populateSparsificationPatterns(
1483 RewritePatternSet &patterns, const SparsificationOptions &options) {
1484 patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1485}
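
A minimal usage sketch (not part of Sparsification.cpp; the function name below is hypothetical and the greedy-driver entry point may be spelled differently across MLIR versions):

// Sketch: wiring the sparsification patterns into a pass or transform.
static LogicalResult runSparsificationSketch(Operation *root,
                                             const SparsificationOptions &options) {
  // Collect the GenericOpSparsifier pattern defined above.
  RewritePatternSet patterns(root->getContext());
  populateSparsificationPatterns(patterns, options);
  // Apply the patterns greedily until a fixed point is reached.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}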