//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenEnv.h"
#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//

/// Returns true iff affine expression is invariant. Sets the
/// parameter `isCurrentLoop` when the expression just became invariant.
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId i = cast<AffineDimExpr>(a).getPosition();
    if (i + 1 == curr) {
      isCurrentLoop = true;
      return true; // becomes invariant at current loop
    }
    return i < curr; // invariant when already generated
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
           isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
  }
  default: {
    assert(isa<AffineConstantExpr>(a));
    return true;
  }
  }
}

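// For example, with curr = 2 the expression d0 + d1 is reported invariant,
// and `isCurrentLoop` is set because d1 (i = 1, i + 1 == curr) becomes
// invariant exactly at the current loop, while d0 alone (i = 0) is invariant
// without setting the flag.
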
/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Constant: {
    assert(lt.hasDenseSemantic());
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
      // We do not set the dim level format for an affine expression like
      // d0 + d1 on either loop index d0 or d1. We continue the recursion
      // merely to check whether the current affine expression is admissible.
      return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
             findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
    }
    // Falls through when it is a constant affine expression.
    return true;
  }
  default:
    return false;
  }
}

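// For example, a map result d0 on a compressed level is admissible and
// records the (tensor, loop, level) triple, whereas reusing the same loop
// index twice (e.g., A[i][i]) is rejected, and a compound result such as
// d0 + d1 is only admissible on levels with dense semantics (where it is
// merely checked, without setting a level format).
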
/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in
/// the current expression we are generating.
///
/// For example, when handling A[i+j][j+k], we build the two-way mapping in
/// merger between (tensor, level) pairs and their dependent index variable
/// set: A_0 <=> [i, j] and A_1 <=> [j, k]
///
/// It rejects cases (returns false)
/// 1st, when the same index is used more than once, e.g., A[i+j][i]
/// 2nd, when multiplication is used in the non-trivial index expression.
/// 3rd, when a constant operand is used in the non-trivial index expression.
///
/// TODO: constants should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false; // used more than once, e.g., A[i][i]

    // TODO: Generalize the following two cases. A[i] (with a trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
    if (!isSubExp) {
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loop appears in more than one affine expression on the
      // same tensor. We cannot handle this case. E.g., in A[i+j][i+k], `i`
      // is used twice.
      if (merger.hasDependentLvl(idx, tensor)) {
        // TODO: This can be supported by coiterating slices if the loop idx
        // appears in affine index expressions of different tensors, or by
        // taking slices on multiple dimensions when it is on the same tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_2
        // else increase min(d0_1, d0_2).
        return false;
      }
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support index expressions like `2 * d0`; we currently only
    // support more complicated cases like `2 * d0 + d1`.
    if (!isSubExp)
      return false;

    // TODO: Support constant AffineExpr for slice-based codegen.
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");

    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    // Must be in form of `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
  }
  default:
    return false;
  }
}

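// To illustrate the recursion: for a level expression 2 * d0 + d1, the Add
// case recurses with isSubExp = true into 2 * d0 and d1; the Mul case then
// extracts coefficient 2 and recurses into d0, so the merger records the
// level as depending on d0 with coefficient 2 and on d1 with coefficient 1.
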
/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following input:
///
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
///
/// Returns 1 (because the first level is compressed and its corresponding
/// indexing-expression is `d0 + d1`)
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
                                                   Value tensor) {
  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
  // However, we don't need to handle `StorageSpecifierType`, so we
  // can use `SparseTensorType` once we guard against non-tensors.
  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  if (!rtp)
    return 0;
  const SparseTensorType stt(rtp);

  const Level lvlRank = stt.getLvlRank();
  const auto exprs = map.getResults();
  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
         "AffineMap does not have dimension-rank many results");
  unsigned num = 0;
  for (Level l = 0; l < lvlRank; l++) {
    if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
      num++;
  }
  return num;
}

/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  unsigned num = 0;
  for (OpOperand &t : op->getOpOperands())
    num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
                                              t.get());
  return num;
}

// Returns true iff output has nontrivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
  OpOperand *out = op.getDpsInitOperand(0);
  if (getSparseTensorType(out->get()).isAllDense())
    return false;
  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
                                            out->get());
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expressions on sparse tensors, and they accept different affine
/// expressions. The dependent index reduction-based approach currently
/// only supports affine addition index expressions.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool annotated = false;
  for (OpOperand &t : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(t.getOperandNumber());
    const auto map = env.op().getMatchingIndexingMap(&t);
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expressions are on dense levels, we can efficiently rely on
    // random access to locate the element.
    bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If the current tensor being inspected requires affine indices, it
    // needs to be sliced.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      const LevelType lt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      } else {
        if (!findAffine(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      }
    }
  }
  return annotated;
}

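// As an illustration: a CSR operand with identity map (d0, d1) -> (d0, d1)
// is handled by the findAffine path, while a convolution-style operand with
// map (d0, d2) -> (d0 + d2) on a sparse level needs the index-reduction path
// (findDepIdxSet) and is otherwise rejected.
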
//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  env.emitter().initializeLoopEmit(
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are
      /// written to (viz. x(i) = y(i) * z(i)), the output buffer is already
      /// initialized to all zeroes and only nonzero values are computed and
      /// written out. For updates (viz. x(i) += y(i) * z(i)), only nonzero
      /// values are used for the updates and no assumption on the original
      /// contents of the output buffer is necessary.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the
        // tensor that appears in the outs() clause. For updates, this has
        // the advantage that only the nonzero values are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we
        // are forced to zero out the buffer to enforce the assumption above,
        // which may negatively impact running complexity (viz. O(n^2 + nnz)
        // vs. O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          linalg::FillOp::create(builder, loc, ValueRange{zero},
                                 ValueRange{init});
        }
        return init;
      },
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}

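// A small example of the zeroing behavior above: for x(i) = y(i) * z(i)
// where x does not come from an initialized outs() tensor, the output
// updater emits a linalg.fill of 0 over x's buffer, so the loops that follow
// only have to write the nonzero products.
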
/// Generates index for load/store on sparse tensor.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  const Level lvlRank = stt.getLvlRank();
  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
  const AffineExpr a = map.getResult(lvlRank - 1);
  assert(a.getKind() == AffineExprKind::DimId);
  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
  return env.getLoopVar(idx);
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {
  const Location loc = env.op().getLoc();
  const TensorId tid = env.makeTensorId(t->getOperandNumber());
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  if (stt.hasEncoding()) {
    // For sparse tensors we only push the last-level's position onto `args`.
    const auto pos = env.emitter().getValPosits(tid);
    assert(!pos.empty());
    args.append(pos);
    // Simply returns the tensor to extract value using iterators.
    if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
      return t->get();
  } else {
    // For dense tensors we push all level's coordinates onto `args`.
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      const auto lvlExpr = map.getResult(l);
      const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
      args.push_back(lvlCrd);
    }
  }
  return env.emitter().getValBuffer()[tid];
}

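// For instance, loading from a dense 2-d operand yields args = {i, j} and a
// memref load buffer[i, j], whereas loading from a CSR operand yields a
// single last-level position in args, so the value load becomes values[pos].
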
/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                              OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct lexicographic coordinate order, tensor loads as zero.
  if (!env.isExpand()) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(builder, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(env, t);
  return memref::LoadOp::create(builder, loc, env.getExpandValues(), index);
}

/// Generates insertion code to implement dynamic tensor load for reduction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                    OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  Value identity = env.getCustomRedId();
  // Direct lexicographic coordinate order, tensor loads as identity.
  if (!env.isExpand())
    return identity;
  // Load from expanded access pattern if filled, identity otherwise.
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value index = genIndex(env, t);
  Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
  Value valAtIndex = memref::LoadOp::create(builder, loc, values, index);
  return arith::SelectOp::create(builder, loc, isFilled, valAtIndex, identity);
}

static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
                                  Value sparseOut, ValueRange ivs, Value v) {
  scf::IfOp condInsert =
      scf::IfOp::create(builder, loc, sparseOut.getType(), cond, true);
  // True branch.
  builder.setInsertionPointToStart(condInsert.thenBlock());
  Value res = tensor::InsertOp::create(builder, loc, v, sparseOut, ivs);
  scf::YieldOp::create(builder, loc, res);
  // False branch.
  builder.setInsertionPointToStart(condInsert.elseBlock());
  scf::YieldOp::create(builder, loc, sparseOut);
  // Value assignment.
  builder.setInsertionPointAfter(condInsert);
  return condInsert.getResult(0);
}

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoops` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
        env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates runtime check for a valid lex during reduction,
      // to avoid inserting the identity value for empty reductions.
      // if (validLexInsert) then
      //   insert(rhs) into chain
      //   return updated chain
      // else
      //   return unmodified chain
      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
                                       chain, ivs, rhs);
      env.updateInsertionChain(out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
        // This is an all-dense -> sparse kernel, test rhs != 0 before
        // insertion.
        Value nz = genIsNonzero(builder, loc, rhs);
        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
      } else {
        sparseOut = tensor::InsertOp::create(builder, loc, rhs, chain, ivs);
      }
      // Generates regular insertion chain.
      env.updateInsertionChain(sparseOut);
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement.
  Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
  Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
                                     isFilled, fval);
  scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIndexType(),
                                     cond, /*else=*/true);
  // True branch.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  memref::StoreOp::create(builder, loc, tval, filled, index);
  memref::StoreOp::create(builder, loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = arith::AddIOp::create(builder, loc, count, one);
  scf::YieldOp::create(builder, loc, add);
  // False branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  scf::YieldOp::create(builder, loc, count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment.
  env.updateExpandCount(ifOp.getResult(0));
  memref::StoreOp::create(builder, loc, rhs, values, index);
}

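// In the expanded pattern above, `values` holds a dense copy of the current
// innermost row, `filled` marks which entries were written, and `added` with
// `count` record their indices, so that a later sparse_tensor.compress (see
// genExpand) can scatter only the touched entries back into the output.
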
/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = env.exp(exp).val;
  if (val)
    return val;
  // Get tensor operand.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
  // Fold binary-valued tensor into explicit value.
  const auto stt = getSparseTensorType(t->get());
  if (auto explVal = stt.getExplicitVal())
    return genValFromAttr(builder, loc, explVal);
  // Load during insertion.
  if (env.isSparseOutput(t)) {
    if (env.isCustomReduc())
      return genInsertionLoadReduce(env, builder, t);
    return genInsertionLoad(env, builder, t);
  }

  // Actual load.
  SmallVector<Value> args;
  Value ptr = genSubscript(env, builder, t, args);
  if (llvm::isa<TensorType>(ptr.getType())) {
    assert(env.options().sparseEmitStrategy ==
           SparseEmitStrategy::kSparseIterator);
    return ExtractValOp::create(builder, loc, ptr,
                                llvm::getSingleElement(args));
  }
  return memref::LoadOp::create(builder, loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output, or otherwise a custom reduction that
  // received no value to accumulate.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Test if this is a scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    memref::StoreOp::create(builder, loc, rhs, ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      scf::IfOp::create(builder, loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  env.merger().clearExprValue(exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  scf::YieldOp::create(builder, op.getLoc(), mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  scf::YieldOp::create(builder, loc, chain);
  // Done with if statement.
  env.updateInsertionChain(ifOp->getResult(0));
  builder.setInsertionPointAfter(ifOp);
}

/// Generates an invariant value.
inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
  return env.exp(exp).val;
}

/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to the actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted into
    // tensor element loads. This handles both dense tensor loads (using
    // current loop coordinates) and sparse tensor loads (using the current
    // value position tracked by the loop emitter).
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      Location loc = op.getLoc();
      if (llvm::isa<TensorType>(ptr.getType())) {
        // kSparseIterator strategy: extract value at the iterator position.
        assert(env.options().sparseEmitStrategy ==
               SparseEmitStrategy::kSparseIterator);
        return ExtractValOp::create(rewriter, loc, ptr,
                                    llvm::getSingleElement(args));
      }
      return memref::LoadOp::create(rewriter, loc, ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  return e;
}

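// For example, if an inlined semi-ring branch computes %arg1 + linalg.index
// 0, the block argument is rewritten into an element load of the matching
// operand and the linalg.index into the loop variable of the generated
// loop nest.
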
/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero
  // value based on the type of the other operand.
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        if (env.isCustomReduc()) {
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
      } else {
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}

/// Generates an expanded access pattern in innermost dimension.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
    return; // not needed at current level
  assert(!env.isReduc());
  // Generate start or end of an expanded access pattern. Note that because
  // an expansion does not rely on the ongoing contents of the sparse storage
  // scheme, we can use the original tensor as incoming SSA value (which
  // simplifies codegen a bit). If expansion on the actual contents is ever
  // needed, we will need to use the SSA value in the insertion chain instead.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto r =
        ExpandOp::create(builder, loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
                    r.getResult(3));
  } else {
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(env.emitter().getLoopIV(i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    Value compress = CompressOp::create(builder, loc, values, filled, added,
                                        count, chain, indices);
    env.updateInsertionChain(compress);
    env.endExpand();
  }
}

/// Returns parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" is a candidate. Whether it is actually
/// converted to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
  // Reject parallelization of sparse output.
  if (env.hasSparseOutput())
    return false;
  // Parallel loops on tensor expansion can cause data races.
  if (env.isExpand())
    return false;
  // Inspect strategy.
  switch (env.options().parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return true;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

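// For example, under kDenseOuterLoop only an outermost loop over dense
// storage is a parallelization candidate, while kAnyStorageAnyLoop also
// admits inner loops that iterate over sparse storage.
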
/// Whether or not the current loop being generated should be parallelized
/// (if possible) according to the configuration.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
                               ArrayRef<TensorLevel> tidLvls) {
  linalg::GenericOp op = env.op();
  auto iteratorTypes = op.getIteratorTypesArray();
  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
    // Queries the LT based on the tensor and loop id, as requested by
    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
    // should be consistent with the LT indexed by <TensorId, Level>.
    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
    return lt.hasSparseSemantic();
  });
  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}

/// Emit a loop to coiterate over the list of tensor levels. The generated
/// loop can either be a for loop or while loop depending on whether there is
/// at most one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                 ArrayRef<TensorLevel> tidLvls,
                                 unsigned numCases, bool tryParallel,
                                 bool needsUniv) {
  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    // Construct while-loop with a parameter for each index.
    return env.emitter().enterCoIterationOverTensorsAtLvls(
        builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
        needsUniv);
  });
  assert(loop);
  return loop;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                          unsigned numCases, bool needsUniv,
                          ArrayRef<TensorLevel> tidLvls) {
  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
                        needsUniv);
}

/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(ifOp->getContext(), "slice"))
        break;

      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(env.getReduc());
        env.updateReduc(ifOp.getResult(y++));
        if (env.isValidLexInsert()) {
          yields.push_back(env.getValidLexInsert());
          env.updateValidLexInsert(ifOp.getResult(y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(env.getExpandCount());
        env.updateExpandCount(ifOp->getResult(y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(env.getInsertionChain());
        env.updateInsertionChain(ifOp->getResult(y++));
      }
      assert(y == yields.size());
      scf::YieldOp::create(builder, loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  // No need to set the insertion point here as LoopEmitter keeps track of the
  // basic block where scf::Yield should be inserted.
}

/// Generates a case region in the coiterate operation.
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
                               unsigned caseIdx, LatPointId allCase,
                               LatPointId curCase,
                               MutableArrayRef<Value> reduc) {
  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
  const BitVector &curCaseBits = env.merger().lat(curCase).simple;

  /// Computes the subset of iterators that are valid in the current case
  /// being generated.
  I64BitSet caseBit(0);
  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
    if (curCaseBits.test(set))
      caseBit.set(idx);

  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(),
                                            caseBit, caseIdx, reduc);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with a non-trivial index
          // expression), we need to reconstruct the tensor level types if
          // this loop requires an index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoord(tid, *lvl);
          const Value lvar = env.getLoopVar(curr);
          clause = arith::CmpIOp::create(builder, loc,
                                         arith::CmpIPredicate::eq, crd, lvar);
        } else {
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, true);
        }
        cond = cond ? arith::AndIOp::create(builder, loc, cond, clause)
                    : clause;
      });
  if (env.isReduc()) {
    types.push_back(env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(env.getInsertionChain().getType());
  scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

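// For instance, when coiterating a compressed level of A with a dense level
// of B, only A contributes a comparison (crd == loop-var); the dense operand
// contributes a constant true clause, so the if-condition reduces to the
// sparse coordinate test.
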
/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(env.getReduc());
    env.updateReduc(redInput);
    if (env.isValidLexInsert()) {
      // Any overlapping indices during a reduction creates a valid lex
      // insert.
      operands.push_back(constantI1(builder, env.op().getLoc(), true));
      env.updateValidLexInsert(validIns);
    }
  }
  if (env.isExpand()) {
    operands.push_back(env.getExpandCount());
    env.updateExpandCount(cntInput);
  }
  if (env.getInsertionChain()) {
    operands.push_back(env.getInsertionChain());
    env.updateInsertionChain(insInput);
  }
  if (!operands.empty())
    scf::YieldOp::create(builder, env.op().getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  const BitVector &simple = env.lat(li).simple;
  const TensorId outTid = env.merger().getOutTensorID();
  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);

  unsigned numloopCond = 0;
  bool hasNonUnique = false;
  env.merger().foreachTensorLoopId(
      li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                    LevelType lt, bool isIdxReduc) {
        if (simple[b]) {
          if (isIdxReduc) {
            callback(env.makeTensorLevel(tid, *lvl), nullptr);
            numloopCond++;
            return;
          }
          if (isUndefLT(lt)) {
            // An undefined lt in the lattices means we probably want to
            // generate a dense loop according to the synthetic tensor (for
            // invariants and sparse output tensor).
            if (env.merger().getSynTensorID() == tid) {
              // Coiterating with an invariant
              //   e.g., out = prod(in[i][j] op invariant);
              // or a broadcast
              //   e.g., out[i][j] = in[i] (j is undef for input)
              //
              // The level of the synthetic tensor is the current loop depth;
              // the rank of the synthetic tensor equals the number of loops.
              assert(curr == env.getCurrentDepth());
              lvl = curr;
            } else if (!lvl) {
              // Skips invalid lvl (e.g., when this is a zero ranked tensor).
              return;
            }
          }
          hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
          numloopCond++;
        } else if (lt.hasDenseSemantic() || isIdxReduc) {
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
        } else {
          assert(isUndefLT(lt));
          linalg::GenericOp op = env.op();
          if (tid >= op.getNumDpsInputs())
            // We only handle affine expressions on input tensors (for now).
            return;
          OpOperand *operand = &op->getOpOperand(tid);
          const auto stt = getSparseTensorType(operand->get());
          // Non-annotated dense tensors require no special handling.
          if (!stt.hasEncoding())
            return;

          ArrayRef<AffineExpr> affines =
              op.getMatchingIndexingMap(operand).getResults();
          const Level lvlRank = stt.getLvlRank();
          assert(affines.size() == static_cast<size_t>(lvlRank));
          for (Level l = 0; l < lvlRank; l++) {
            AffineExpr exp = affines[l];
            // Skip simple affine expressions and non-dense levels (which
            // have their own filter loop).
            LevelType lt = stt.getLvlType(l);
            if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
              continue;

            // Constant affine expressions are handled in genLoop.
            if (!isa<AffineConstantExpr>(exp)) {
              bool isCurrentLoop = false;
              assert(curr == env.getCurrentDepth());
              if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
                  isCurrentLoop) {
                // If the compound affine is invariant and we are right at
                // the level, we need to generate the address according to
                // the affine expression. This is also the best place we can
                // do it to avoid putting it inside inner loops.
                callback(env.makeTensorLevel(tid, l), exp);
              }
            }
          }
        }
      });

  if (isDenseLT(env.lt(outTid, curr))) {
    auto stt = getSparseTensorType(env.op().getOutputs().front());
    // Note that we generate dense indices of the output tensor
    // unconditionally, since they may not appear in the lattice, but may be
    // needed for linearized env.
    // TODO: we should avoid introducing corner cases for all-dense sparse
    // tensors.
    if (stt.hasEncoding() && stt.isAllDense())
      callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
  }

  if (numloopCond == 0) {
    // Corner cases where the loop bound is defined by an *unused* operand;
    // in this case, we just generate a dense "fake" loop by iterating over
    // the synthetic tensor.
    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr),
             nullptr);
    numloopCond++;
  }
  // If we only need one loop condition and the condition is not imposed on
  // a non-unique level, the loop can be generated by a for loop.
  // Or, if we are generating sparse-iterator-based loops, we always generate
  // `sparse_tensor.iterate` regardless of whether the level is unique or not.
  return numloopCond == 1 &&
         (!hasNonUnique || env.options().sparseEmitStrategy ==
                               SparseEmitStrategy::kSparseIterator);
}

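// As an illustration of the return value: a lattice point that iterates a
// single unique compressed level yields numloopCond == 1 and can be emitted
// as a plain for loop, whereas coiterating two compressed operands (e.g., in
// a disjunction like a(i) + b(i)) requires a while-loop, unless
// sparse-iterator-based emission is requested.
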
/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                         LoopId curr, LatSetId lts) {
  assert(!env.getLoopVar(curr));
  // Emit invariants at this loop sequence level.
  genInvariants(env, builder, exp, curr, /*isStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpand(env, builder, curr, /*isStart=*/true);
  // Emit further initialization at this loop sequence level.
  const LatPointId l0 = env.set(lts)[0];

  SmallVector<TensorLevel> tidLvls;
  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
    // TODO: remove this! The same tensor level might be added multiple
    // times due to the special handling for all-dense "sparse" output tensor
    // (see L1038).
    if (llvm::is_contained(tidLvls, tl))
      return;
    tidLvls.emplace_back(tl);
  });

  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);

  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  for (const LatPointId li : env.set(lts).drop_front())
    if (!env.merger().hasAnySparse(env.lat(li).simple))
      return true;

  return false;
}

// Generates dense affine address for encoding.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                             OpBuilder &builder, TensorId tid,
                                             Level startLvl) {
  // TODO: Handle affine expression on output tensor.
  linalg::GenericOp op = env.op();
  assert(tid < op.getNumDpsInputs());
  OpOperand *input = op.getDpsInputOperands()[tid];
  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
  const auto enc = getSparseTensorEncoding(input->get().getType());
  if (enc) {
    const Location loc = op.getLoc();
    const TensorId tid = env.makeTensorId(input->getOperandNumber());
    const Level lvlRank = enc.getLvlRank();
    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
    for (Level l = startLvl; l < lvlRank; l++) {
      AffineExpr lvlExpr = lvlExprs[l];
      if (enc.getLvlType(l).hasDenseSemantic() &&
          isa<AffineConstantExpr>(lvlExpr))
        env.emitter().locateLvlAtAffineAddress(
            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
      else
        return; // break on first non-dense non-constant level
    }
  }
}

// We can generate addresses for constant affine expressions before any loops,
// starting from the first level, as they do not depend on anything.
// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
// levels can be determined before loops.
static void genInitConstantDenseAddress(CodegenEnv &env,
                                        RewriterBase &rewriter) {
  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}

/// Returns true if the lattice bit can be iterated by a for loop.
static bool translateBitsToTidLvlPairs(
    CodegenEnv &env, LatPointId li, LoopId curr,
    SmallVectorImpl<TensorLevel> &tidLvls,
    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  return getAllTidLvlsInLatPoints(env, li, curr,
                                  [&](TensorLevel tl, AffineExpr exp) {
                                    if (exp)
                                      affineTidLvls.emplace_back(tl, exp);
                                    else
                                      tidLvls.emplace_back(tl);
                                  });
}

/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                              OpBuilder &builder, LoopId curr,
                                              LatPointId li, unsigned numCases,
                                              bool needsUniv) {
  // TODO: numCases is only used when generating iterator-based loops. Clean
  // up once the migration is complete.
  // The set of tensors + lvls to generate loops on.
  SmallVector<TensorLevel> tidLvls;

  // The set of dense tensors with non-trivial affine expressions that just
  // became invariant, whose addresses are generated at the current level.
  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
  bool isSingleCond =
      translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);

  // Emit the for/while-loop control.
  Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
  Location loc = env.op().getLoc();
  for (auto [tidLvl, exp] : affineTidLvls) {
    env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
  }

  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
  // affine}Tids/Lvls. The addresses of the upcoming levels which are
  // dependent on constant affine expressions may now be determined.
  auto allTidLvls =
      llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
    if (tid != env.merger().getOutTensorID() &&
        tid != env.merger().getSynTensorID())
      genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
  }

  return std::make_pair(loop, isSingleCond);
}

/// Ends a single loop in current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                    LatPointId li, bool needsUniv, bool isSingleCond) {
  // Either a for-loop or a while-loop that iterates over a slice.
  if (isSingleCond) {
    // Any iteration creates a valid lex insert.
    if (env.isReduc() && env.isValidLexInsert())
      env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    // End a while-loop.
    finalizeWhileOp(env, rewriter, needsUniv);
  } else {
    needsUniv = false;
  }
  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
    return std::nullopt;
  });
  return needsUniv;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                       unsigned at) {
  assert(!env.getLoopVar(at));
  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(env, builder, exp, at, /*isStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpand(env, builder, at, /*isStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {
  assert(curr == env.getCurrentDepth());

  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (curr == env.getLoopNum()) {
    Value rhs = genExp(env, rewriter, exp);
    genTensorStore(env, rewriter, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index.
  const LatSetId lts =
      env.merger().optimizeSet(env.merger().buildLattices(exp, curr));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);

  // When using sparse-iterator-based loops, we only need one loop, as
  // opposed to a loop sequence, to cover all the iterator spaces.
  const unsigned lsize = env.set(lts).size();
  if (env.generatingSparseIterator()) {
    // Get the largest lattice point and start a loop.
    const LatPointId li = env.set(lts)[0];
    auto [loop, isSingleCond] =
        startLoop(env, rewriter, curr, li, lsize, needsUniv);
    assert(isSingleCond == llvm::isa<IterateOp>(loop));
    // We cannot change this to `for (const LatPointId li : env.set(lts))`
    // because the loop body causes data-movement which invalidates
    // the iterator.
    for (unsigned j = 0; j < lsize; j++) {
      const LatPointId lj = env.set(lts)[j];
      const ExprId ej = env.lat(lj).exp;
      // Recurse into body of each branch.
      if (!isSingleCond) {
        env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
          genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
          genStmt(env, rewriter, ej, curr + 1);
          // TODO: handle yield values.
          assert(reduc.empty() && "Not Implemented");
          sparse_tensor::YieldOp::create(rewriter, env.op().getLoc());
          return std::nullopt;
        });
        // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
      } else {
        genStmt(env, rewriter, ej, curr + 1);
      }
    }
    // End a loop.
    needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
  } else {
    // Emit a loop for every lattice point L0 >= Li in this loop sequence.
    for (unsigned i = 0; i < lsize; i++) {
      const LatPointId li = env.set(lts)[i];
      // Start a loop.
      auto [loop, isSingleCond] =
          startLoop(env, rewriter, curr, li, lsize, needsUniv);

      // Visit all lattice points with Li >= Lj to generate the
      // loop-body, possibly with if statements for coiteration.
      Value redInput = env.getReduc();
      Value cntInput = env.getExpandCount();
      Value insInput = env.getInsertionChain();
      Value validIns = env.getValidLexInsert();
      // We cannot change this to `for (const LatPointId lj : env.set(lts))`
      // because the loop body causes data-movement which invalidates the
      // iterator.
      for (unsigned j = 0; j < lsize; j++) {
        const LatPointId lj = env.set(lts)[j];
        const ExprId ej = env.lat(lj).exp;
        if (li == lj || env.merger().latGT(li, lj)) {
          // Recurse into body of each branch.
          if (!isSingleCond) {
            scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
            genStmt(env, rewriter, ej, curr + 1);
            endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
          } else {
            genStmt(env, rewriter, ej, curr + 1);
          }
        }
      }

      // End a loop.
      needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
    }
  }

  // End a loop sequence.
  endLoopSeq(env, rewriter, exp, curr);
  assert(curr == env.getCurrentDepth());
}

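// A sketch of the recursion: for x(i) = a(i) + b(i) with both inputs
// compressed, the optimized lattice set for loop i contains the points
// {a,b}, {a}, and {b}; genStmt emits a while-loop that coiterates a and b
// with if-statements selecting the overlap and singleton cases, recursing
// one loop deeper for each branch until the leaf stores into x.
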
/// Converts the result computed by the sparse kernel into the required form.
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  Value tensor = lhs->get();
  Type resType = tensor.getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format. For an insertion chain, the
    // tensor materializes from the chain with 'hasInserts' enabled.
    bool hasInserts = false;
    if (Value chain = env.getInsertionChain()) {
      hasInserts = true;
      tensor = chain;
    }
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparsifier rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

1403 LogicalResult matchAndRewrite(linalg::GenericOp op,
1404 PatternRewriter &rewriter) const override {
1405 // Only accept single output operations with pure tensor semantics.
1406 if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1407 return failure();
1408
1409 // Only accept trivial affine indices.
1411 return failure();
1412
1413 // Only accept scheduled loops.
1414 if (!op->hasAttr("sorted")) {
1415 return rewriter.notifyMatchFailure(
1416 op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1417 "before sparsification.");
1418 }
1419
1420 // Must have been demapped as well if the generic op is sorted.
1422
1423 // Sets up a code generation environment.
1424 const unsigned numTensors = op->getNumOperands();
1425 const unsigned numLoops = op.getNumLoops();
1426 bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1427 // If we have indexing map like (d0) -> (0, d0), there might be more
1428 // levels then loops because of the constant index, that means we can not
1429 // use numLoops as the upper bound for ranks of all tensors.
1430 // TODO: Constant indices are currently not support on sparse tensor, but
1431 // are allowed in non-annotated dense tensor. Support it, it would be
1432 // required for sparse tensor slice rank reducing too.
1433 Level maxLvlRank = 0;
1434 for (auto operand : op.getOperands()) {
1435 if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1436 maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1437 }
1438 }
1439
1440 // Detects sparse annotations and translates the per-level sparsity
1441 // information for all tensors to loop indices in the kernel.
1442 CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1443 if (!findSparseAnnotations(env, needIdxRed))
1444 return failure();
1445
1446 // Only standard reduction operations (add, sub, or, xor) that can be
1447 // sparsified by merely reducing the stored values are admissible. More
1448 // elaborate reduction operations (such as mul, and, min, max) would need
1449 // to know whether implicit zeros occur as well. They can still be
1450 // implemented with a custom reduction operation, accepted here as well.
1451 if (op.getNumReductionLoops() > 0) {
1452 Operation *yield = op.getRegion().front().getTerminator();
1453 assert(isa<linalg::YieldOp>(yield));
1454 Operation *redop = yield->getOperand(0).getDefiningOp();
1455 if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1456 !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1457 !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1458 !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1459 !isa<ReduceOp>(redop)) {
1460 return failure();
1461 }
1462 }
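// E.g., x += A(i, j) reduced with arith.addf is admissible as-is, while a
// min-reduction must be expressed through a custom sparse_tensor.reduce
// region to be accepted here.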
1463
1464 // Constructs the tensor expression tree from `op`; returns failure if the
1465 // tree cannot be built or the tensor expression is inadmissible.
1466 if (failed(env.initTensorExp()))
1467 return failure();
1468
1469 // Recursively generates code if admissible.
1470 env.startEmit(options.sparseEmitStrategy);
1471 genBuffers(env, rewriter);
1472 // TODO: Constant affine expressions should be handled differently when
1473 // using slice-based codegen; it does not matter now because we already
1474 // reject constant expressions at an earlier stage.
1475 genInitConstantDenseAddress(env, rewriter);
1476 genStmt(env, rewriter, env.getExprId(), 0);
1477 genResult(env, rewriter);
1478 return success();
1479 }
1480
1481private:
1482 /// Options to control sparse code generation.
1483 SparsificationOptions options;
1484};
1485
1486} // namespace
1487
1488/// Populates the given patterns list with rewriting rules required for
1489/// the sparsification of linear algebra operations.
1490void mlir::populateSparsificationPatterns(
1491 RewritePatternSet &patterns, const SparsificationOptions &options) {
1492 patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1493}
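// A minimal usage sketch (assuming a surrounding pass and an `options`
// value; the actual pass wiring lives elsewhere in the dialect):
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));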