1//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// A pass that converts sparse tensor types and primitives to actual compiler
10// visible buffers and actual compiler IR that implements these primitives on
11// the selected sparse tensor storage schemes. This pass provides an alternative
12// to the SparseTensorConversion pass, eliminating the dependence on a runtime
13// support library (other than for file I/O), and providing many more
14// opportunities for subsequent compiler optimization of the generated code.
15//
16//===----------------------------------------------------------------------===//
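//
// Illustrative sketch (assuming a standard CSR encoding #CSR with index-typed
// positions and coordinates): a matrix of type
//   tensor<?x?xf64, #CSR>
// is lowered to a flat collection of buffers along the lines of
//   memref<?xindex>   positions at the compressed level
//   memref<?xindex>   coordinates at the compressed level
//   memref<?xf64>     values
//   !sparse_tensor.storage_specifier<#CSR>   per-level sizes and buffer sizes
//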
17
18#include "Utils/CodegenUtils.h"
19#include "Utils/SparseTensorDescriptor.h"
20
21#include "mlir/Dialect/Arith/Utils/Utils.h"
22#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Linalg/Utils/Utils.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SparseTensor/IR/Enums.h"
27#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30#include "mlir/Dialect/Tensor/IR/Tensor.h"
31#include "mlir/Transforms/DialectConversion.h"
32
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::sparse_tensor;
37
38//===----------------------------------------------------------------------===//
39// Helper methods.
40//===----------------------------------------------------------------------===//
41
42/// Flatten the given value ranges into a single vector of values.
43static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
44  SmallVector<Value> result;
45  for (const auto &vals : values)
46 llvm::append_range(result, vals);
47 return result;
48}
49
50/// Generates a load with proper `index` typing.
51static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
52 idx = genCast(builder, loc, idx, builder.getIndexType());
53 return memref::LoadOp::create(builder, loc, mem, idx);
54}
55
56/// Generates a store with proper `index` typing and proper value.
57static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
58 Value idx) {
59 idx = genCast(builder, loc, idx, builder.getIndexType());
60 val = genCast(builder, loc, val,
61 cast<ShapedType>(mem.getType()).getElementType());
62 memref::StoreOp::create(builder, loc, val, mem, idx);
63}
64
65/// Creates a straightforward counting for-loop.
66static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
67                            MutableArrayRef<Value> fields,
68                            Value lower = Value()) {
69 Type indexType = builder.getIndexType();
70 if (!lower)
71 lower = constantZero(builder, loc, indexType);
72 Value one = constantOne(builder, loc, indexType);
73 scf::ForOp forOp =
74 scf::ForOp::create(builder, loc, lower, upper, one, fields);
75 for (unsigned i = 0, e = fields.size(); i < e; i++)
76 fields[i] = forOp.getRegionIterArg(i);
77 builder.setInsertionPointToStart(forOp.getBody());
78 return forOp;
79}
80
81/// Creates a push back operation.
82static void createPushback(OpBuilder &builder, Location loc,
83                           MutSparseTensorDescriptor &desc,
84                           SparseTensorFieldKind kind, std::optional<Level> lvl,
85 Value value, Value repeat = Value()) {
86 Type etp = desc.getMemRefElementType(kind, lvl);
87 Value field = desc.getMemRefField(kind, lvl);
88 StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
89
90 auto pushBackOp = PushBackOp::create(
91 builder, loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl),
92 field, genCast(builder, loc, value, etp), repeat);
93
94 desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
95 desc.setSpecifierField(builder, loc, specFieldKind, lvl,
96 pushBackOp.getNewSize());
97}
98
99/// Generates code that allocates a sparse storage scheme for given rank.
100static void allocSchemeForRank(OpBuilder &builder, Location loc,
101 MutSparseTensorDescriptor desc, Level startLvl) {
102 const SparseTensorType stt(desc.getRankedTensorType());
103 Value linear = constantIndex(builder, loc, 1);
104 const Level lvlRank = stt.getLvlRank();
105 for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
106 const auto lt = stt.getLvlType(lvl);
107 if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
108 // Append linear x positions, initialized to zero. Since each compressed
109 // dimension initially already has a single zero entry, this maintains
110 // the desired "linear + 1" length property at all times. For loose
111 // compression, we multiply linear by two in order to append both the
112 // lo/hi positions.
113 Value posZero = constantZero(builder, loc, stt.getPosType());
114 if (isLooseCompressedLT(lt)) {
115 Value two = constantIndex(builder, loc, 2);
116 linear = arith::MulIOp::create(builder, loc, linear, two);
117 }
118 createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
119 /*value=*/posZero, /*repeat=*/linear);
120 return;
121 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
122 return; // nothing to do
123 }
124 // Keep compounding the size, but nothing needs to be initialized
125 // at this level. We will eventually reach a compressed level or
126 // otherwise the values array for the from-here "all-dense" case.
127 assert(isDenseLT(lt));
128 Value size = desc.getLvlSize(builder, loc, lvl);
129 linear = arith::MulIOp::create(builder, loc, linear, size);
130 }
131 // Reached values array so prepare for an insertion.
132 Value valZero = constantZero(builder, loc, stt.getElementType());
133  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
134                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
135}
136
137/// Creates allocation operation.
138static Value createAllocation(OpBuilder &builder, Location loc,
139                              MemRefType memRefType, Value sz,
140 bool enableInit) {
141 Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
142 Type elemType = memRefType.getElementType();
143 if (enableInit) {
144 Value fillValue = constantZero(builder, loc, elemType);
145 linalg::FillOp::create(builder, loc, fillValue, buffer);
146 }
147 return buffer;
148}
149
150/// Creates the dim sizes array, filling in from dynamic sizes.
151static void createDimSizes(OpBuilder &builder, Location loc,
152 SparseTensorType stt, ValueRange dynSizes,
153 /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
154 const Dimension dimRank = stt.getDimRank();
155 dimSizesValues.clear();
156 dimSizesValues.reserve(dimRank);
157 unsigned i = 0;
158 for (const Size sz : stt.getDimShape())
159 dimSizesValues.push_back(ShapedType::isDynamic(sz)
160 ? dynSizes[i++]
161 : constantIndex(builder, loc, sz));
162}
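// For example (illustrative), a type tensor<8x?xf64, #SV> with one dynamic
// size %d yields dimSizesValues = [8, %d], where the static 8 is materialized
// as an index constant.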
163
164/// Creates allocation for each field in sparse tensor type. Note that
165/// for all dynamic memrefs in the sparse tensor storage layout, the
166/// memory size is really the capacity of the "vector", while the actual
167/// size resides in the sizes array.
168static void createAllocFields(OpBuilder &builder, Location loc,
169 SparseTensorType stt, bool enableInit,
170 Value sizeHint,
171 SmallVectorImpl<Value> &lvlSizesValues,
172 /*out*/ SmallVectorImpl<Value> &fields) {
173 Level lvlRank = stt.getLvlRank();
174 // Set up some heuristic sizes. We try to set the initial
175 // size based on available information. Otherwise we just
176 // initialize a few elements to start the reallocation chain.
177 // TODO: refine this
178 Value posHeuristic, crdHeuristic, valHeuristic;
179 if (stt.isAllDense()) {
180 valHeuristic = lvlSizesValues[0];
181 for (Level lvl = 1; lvl < lvlRank; lvl++)
182 valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
183 lvlSizesValues[lvl]);
184 } else if (sizeHint) {
185 if (stt.getAoSCOOStart() == 0) {
186 posHeuristic = constantIndex(builder, loc, 2);
187 crdHeuristic = arith::MulIOp::create(
188 builder, loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
189 } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
190 posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
191 constantIndex(builder, loc, 1));
192 crdHeuristic = sizeHint;
193 } else {
194 posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
195 }
196 valHeuristic = sizeHint;
197 } else {
198 posHeuristic = crdHeuristic = valHeuristic =
199 constantIndex(builder, loc, 16);
200 }
201 // Initializes all fields. An initial storage specifier and allocated
202 // positions/coordinates/values memrefs (with heuristic capacity).
203  foreachFieldAndTypeInSparseTensor(
204      stt,
205 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
206 enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
207 Level /*lvl*/, LevelType /*lt*/) -> bool {
208 assert(fields.size() == fIdx);
209 Value field;
210 switch (fKind) {
211        case SparseTensorFieldKind::StorageSpec:
212          field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
213 break;
214        case SparseTensorFieldKind::PosMemRef:
215          field = createAllocation(builder, loc, cast<MemRefType>(fType),
216 posHeuristic, enableInit);
217 break;
218        case SparseTensorFieldKind::CrdMemRef:
219          field = createAllocation(builder, loc, cast<MemRefType>(fType),
220 crdHeuristic, enableInit);
221 break;
222        case SparseTensorFieldKind::ValMemRef:
223          field = createAllocation(builder, loc, cast<MemRefType>(fType),
224 valHeuristic, enableInit);
225 break;
226 }
227 assert(field);
228 fields.push_back(field);
229 // Returns true to continue the iteration.
230 return true;
231 });
232 // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
233 // and gives all position fields an initial zero entry, so that it is
234 // easier to maintain the "linear + 1" length property.
235 MutSparseTensorDescriptor desc(stt, fields);
236 Value posZero = constantZero(builder, loc, stt.getPosType());
237 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
238 desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
239 const auto lt = stt.getLvlType(lvl);
240 if (isCompressedLT(lt) || isLooseCompressedLT(lt))
241 createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
242 /*value=*/posZero);
243 }
244 allocSchemeForRank(builder, loc, desc, /*rank=*/0);
245}
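// Illustrative sketch: for an empty m x n CSR matrix this yields a specifier
// with lvlSizes [m, n], a positions buffer seeded with a single zero (grown
// to m+1 zeros by allocSchemeForRank), and still-empty coordinates/values
// buffers that merely carry the heuristic capacity.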
246
247/// Helper method that generates block specific to compressed case:
248///
249/// // given: parentPos = posCursor[lvl-1]
250/// pstart = desc.positions[lvl][parentPos]
251/// pstop = desc.positions[lvl][parentPos+1]
252/// plast = pstop - 1
253/// msz = desc.coordinates[lvl].size()
254/// if (pstart < pstop) {
255/// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
256/// } else { // first insertion
257/// isPresent = false
258/// desc.positions[lvl][parentPos] = msz
259/// }
260/// if (isPresent) { // coordinate is already present
261/// pnext = plast
262/// } else {
263/// desc.coordinates[lvl].push_back(lvlCoords[lvl])
264/// desc.positions[lvl][parentPos+1] = msz+1
265/// pnext = msz
266/// <prepare level lvl+1>
267/// }
268/// posCursor[lvl] = pnext
269static Value genCompressed(OpBuilder &builder, Location loc,
270                           MutSparseTensorDescriptor desc, ValueRange lvlCoords,
271                           Value /*unused*/, Value parentPos, Level lvl) {
272 const SparseTensorType stt(desc.getRankedTensorType());
273 const Level lvlRank = stt.getLvlRank();
274 assert(lvl < lvlRank && "Level is out of bounds");
275 assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
276 "Level-rank mismatch");
277 SmallVector<Type> types;
278 Type indexType = builder.getIndexType();
279 Type boolType = builder.getIntegerType(1);
280 unsigned crdFidx;
281 unsigned crdStride;
282 std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
283 const Value one = constantIndex(builder, loc, 1);
284 const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
285 const Value positionsAtLvl = desc.getPosMemRef(lvl);
286 const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
287 const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
288 const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
289 const Value crdStrideC =
290 crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
291 const Value msz =
292 crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
293 : crdMsz;
294 const Value plast = arith::SubIOp::create(
295 builder, loc, genCast(builder, loc, pstop, indexType), one);
296 // Conditional expression.
297 Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
298 pstart, pstop);
299 types.push_back(boolType);
300 scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt, /*else*/ true);
301 types.pop_back();
302 builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
303 Value crd = genLoad(
304 builder, loc, desc.getMemRefField(crdFidx),
305 crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
306 : plast);
307 Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
308 genCast(builder, loc, crd, indexType),
309 lvlCoords[lvl]);
310 scf::YieldOp::create(builder, loc, eq);
311 builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
312 if (lvl > 0)
313 genStore(builder, loc, msz, positionsAtLvl, parentPos);
314 scf::YieldOp::create(builder, loc, constantI1(builder, loc, false));
315 builder.setInsertionPointAfter(ifOp1);
316 // If present construct. Note that for a non-unique dimension level, we
317 // simply set the condition to false and rely on CSE/DCE to clean up the IR.
318 //
319 // TODO: generate less temporary IR?
320 //
321 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
322 types.push_back(desc.getField(i).getType());
323 types.push_back(indexType);
324 const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
325 : constantI1(builder, loc, false);
326 scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p, /*else*/ true);
327 // If present (fields unaffected, update pnext to plast).
328 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
329
330  // FIXME: This does not look like a clean way, but probably the most
331 // efficient way.
332 desc.getFields().push_back(plast);
333 scf::YieldOp::create(builder, loc, desc.getFields());
334 desc.getFields().pop_back();
335
336 // If !present (changes fields, update pnext).
337 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
338 Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
339 genStore(builder, loc, mszp1, positionsAtLvl, pp1);
340 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
341 /*value=*/lvlCoords[lvl]);
342 // Prepare the next level "as needed".
343 if ((lvl + 1) < lvlRank)
344 allocSchemeForRank(builder, loc, desc, lvl + 1);
345
346 desc.getFields().push_back(msz);
347 scf::YieldOp::create(builder, loc, desc.getFields());
348 desc.getFields().pop_back();
349
350 // Update fields and return next pos.
351 builder.setInsertionPointAfter(ifOp2);
352 unsigned o = 0;
353 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
354 desc.setField(i, ifOp2.getResult(o++));
355 return ifOp2.getResult(o);
356}
357
358/// Generates insertion finalization code.
359static void genEndInsert(OpBuilder &builder, Location loc,
360                         SparseTensorDescriptor desc) {
361  const SparseTensorType stt(desc.getRankedTensorType());
362 const Level lvlRank = stt.getLvlRank();
363 for (Level lvl = 0; lvl < lvlRank; lvl++) {
364 const auto lt = stt.getLvlType(lvl);
365 if (isCompressedLT(lt)) {
366 // Compressed dimensions need a position cleanup for all entries
367 // that were not visited during the insertion pass.
368 //
369 // TODO: avoid cleanup and keep compressed scheme consistent at all
370 // times?
371 //
372 if (lvl > 0) {
373 Type posType = stt.getPosType();
374 Value posMemRef = desc.getPosMemRef(lvl);
375 Value hi = desc.getPosMemSize(builder, loc, lvl);
376 Value zero = constantIndex(builder, loc, 0);
377 Value one = constantIndex(builder, loc, 1);
378 // Vector of only one, but needed by createFor's prototype.
379 SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
380 scf::ForOp loop = createFor(builder, loc, hi, inits, one);
381 Value i = loop.getInductionVar();
382 Value oldv = loop.getRegionIterArg(0);
383 Value newv = genLoad(builder, loc, posMemRef, i);
384 Value posZero = constantZero(builder, loc, posType);
385 Value cond = arith::CmpIOp::create(
386 builder, loc, arith::CmpIPredicate::eq, newv, posZero);
387 scf::IfOp ifOp = scf::IfOp::create(builder, loc, TypeRange(posType),
388 cond, /*else*/ true);
389 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
390 genStore(builder, loc, oldv, posMemRef, i);
391 scf::YieldOp::create(builder, loc, oldv);
392 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
393 scf::YieldOp::create(builder, loc, newv);
394 builder.setInsertionPointAfter(ifOp);
395 scf::YieldOp::create(builder, loc, ifOp.getResult(0));
396 builder.setInsertionPointAfter(loop);
397 }
398 } else {
399 assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
400 isNOutOfMLT(lt));
401 }
402 }
403}
404
405/// Generates a subview into the sizes.
406static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
407 Value sz) {
408 auto memTp = llvm::cast<MemRefType>(mem.getType());
409 // For higher-dimensional memrefs, we assume that the innermost
410 // dimension is always of the right size.
411 // TODO: generate complex truncating view here too?
412 if (memTp.getRank() > 1)
413 return mem;
414 // Truncate linear memrefs to given size.
415 return memref::SubViewOp::create(
416 builder, loc,
417 MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
418 mem, ValueRange{}, ValueRange{sz}, ValueRange{},
419 ArrayRef<int64_t>{0}, // static offset
420 ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
421 ArrayRef<int64_t>{1}) // static stride
422 .getResult();
423}
424
425/// Creates the reassociation array.
426static SmallVector<ReassociationIndices>
427getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
428 SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
429 // Create reassociation in the form:
430 // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
431 for (unsigned i = 0; i < batchLvls; i++)
432 ret[i].push_back(i);
433
434 for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
435 ret.back().push_back(i);
436
437 return ret;
438}
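// For example, flattening a rank-4 buffer with batchLvls = 1 produces the
// reassociation {{0}, {1, 2, 3}}.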
439
440//===----------------------------------------------------------------------===//
441// Codegen rules.
442//===----------------------------------------------------------------------===//
443
444namespace {
445
446/// Helper class to help lowering sparse_tensor.insert operation.
447class SparseInsertGenerator
448 : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
449public:
450 SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
451 bool genCall)
452 : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {};
453
454 /// Generates code along an insertion path without the need for a "cursor".
455 /// This current insertion strategy comes at the expense of some testing
456 /// overhead for each insertion. The strategy will be optimized later for
457 /// common insertion patterns. The current insertion strategy also assumes
458 /// insertions occur in "a reasonable order" that enables building the
459 /// storage scheme in an appending/inserting kind of fashion (i.e. no
460 /// in-between insertions that need data movement). The implementation
461 /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
462 ///
463 /// TODO: better unord/not-unique; also generalize, optimize, specialize!
464 SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
465 OpBuilder &builder, Location loc) {
466 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
467 const Level lvlRank = stt.getLvlRank();
468 // Extract fields and coordinates from args.
469 SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
470 MutSparseTensorDescriptor desc(stt, fields);
471 const SmallVector<Value> coords =
472 llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
473 Value value = args.back();
474 Value parentPos = constantZero(builder, loc, builder.getIndexType());
475 // Generate code for every level.
476 for (Level lvl = 0; lvl < lvlRank; lvl++) {
477 const auto lt = stt.getLvlType(lvl);
478 if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
479 // Create:
480 // if (!present) {
481 // coordinates[lvl].push_back(coords[lvl])
482 // <update positions and prepare level lvl + 1>
483 // }
484 // positions[lvl] = coordinates.size() - 1
485 // <insert @ positions[lvl] at next level lvl + 1>
486 if (isLooseCompressedLT(lt)) {
487 Value two = constantIndex(builder, loc, 2);
488 parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
489 }
490 parentPos =
491 genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
492 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
493 // Create:
494 // coordinates[lvl].push_back(coords[lvl])
495 // positions[lvl] = positions[lvl-1]
496 // <insert @ positions[lvl] at next level lvl + 1>
497 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
498 lvl, /*value=*/coords[lvl]);
499 } else {
500 assert(isDenseLT(lt));
501 // Construct the new position as:
502 // positions[lvl] = size * positions[lvl-1] + coords[lvl]
503 // <insert @ positions[lvl] at next level lvl + 1>
504 Value size = desc.getLvlSize(builder, loc, lvl);
505 Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
506 parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
507 }
508 }
509 // Reached the actual value append/insert.
510 if (!stt.isDenseLvl(lvlRank - 1))
511 createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
512 std::nullopt, value);
513 else
514 genStore(builder, loc, value, desc.getValMemRef(), parentPos);
515 return fields;
516 }
517
518 std::string getMangledFuncName() {
519 // The mangled name of the function has this format:
520 // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
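    // For example (illustrative), inserting into an 8x8 CSR matrix of f64
    // with default position/coordinate widths yields a name roughly like
    //   _insert_dense_compressed_8_8_f64_0_0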
521 constexpr const char kInsertFuncNamePrefix[] = "_insert_";
522 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
523 SmallString<32> nameBuffer;
524 llvm::raw_svector_ostream nameOstream(nameBuffer);
525 nameOstream << kInsertFuncNamePrefix;
526 const Level lvlRank = stt.getLvlRank();
527 for (Level l = 0; l < lvlRank; l++) {
528 std::string lvlType = toMLIRString(stt.getLvlType(l));
529 // Replace/remove punctuations in level properties.
530 std::replace_if(
531 lvlType.begin(), lvlType.end(),
532 [](char c) { return c == '(' || c == ','; }, '_');
533 llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
534 nameOstream << lvlType << "_";
535 }
536 // Static dim sizes are used in the generated code while dynamic sizes are
537 // loaded from the dimSizes buffer. This is the reason for adding the shape
538 // to the function name.
539 for (const auto sz : stt.getDimShape())
540 nameOstream << sz << "_";
541 // Permutation information is also used in generating insertion.
542 if (!stt.isIdentity())
543 nameOstream << stt.getDimToLvl() << "_";
544 nameOstream << stt.getElementType() << "_";
545 nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
546 return nameOstream.str().str();
547 }
548
549private:
550 TensorType rtp;
551};
552
553/// Sparse tensor storage conversion rule for returns.
554class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
555public:
556 using OpConversionPattern::OpConversionPattern;
557 LogicalResult
558 matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter) const override {
560 // Create a return with the flattened value extracted from sparse tensors.
561 rewriter.replaceOpWithNewOp<func::ReturnOp>(
562 op, flattenValues(adaptor.getOperands()));
563 return success();
564 }
565};
566
567/// Sparse tensor storage conversion rule for calls.
568class SparseCallConverter : public OpConversionPattern<func::CallOp> {
569public:
570 // The default CallOp converter can not handle 1:N type conversion.
571 using OpConversionPattern::OpConversionPattern;
572 LogicalResult
573 matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
574 ConversionPatternRewriter &rewriter) const override {
575 Location loc = op.getLoc();
576 // In case of:
577 // sparse_tensor, f, sparse_tensor = call @foo(...)
578 // ==>
579 // memref..., f, memref = call @foo(...) replace with
580 // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
581 SmallVector<Type> finalRetTy;
582 if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
583 return failure();
584
585 // (1) Generates new call with flattened return value.
586 auto newCall =
587 func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
588 flattenValues(adaptor.getOperands()));
589 // (2) Gather sparse tensor returns.
590 SmallVector<SmallVector<Value>> packedResultVals;
592    // Tracks the offset of the current return value (of the original call)
593    // relative to the new call (after sparse tensor flattening).
593 unsigned retOffset = 0;
594    // Temporary buffer to hold the flattened list of types for
595 // a sparse tensor.
596 SmallVector<Type> sparseFlat;
597 for (auto ret : op.getResults()) {
598 assert(retOffset < newCall.getNumResults());
599 auto retType = ret.getType();
600 if (failed(typeConverter->convertType(retType, sparseFlat)))
601 llvm_unreachable("Failed to convert type in sparse tensor codegen");
602
603      // Converted types cannot be empty when the type conversion succeeds.
604 assert(!sparseFlat.empty());
605 if (sparseFlat.size() > 1) {
606 auto flatSize = sparseFlat.size();
607 packedResultVals.emplace_back();
608 llvm::append_range(packedResultVals.back(),
609 newCall.getResults().slice(retOffset, flatSize));
610 retOffset += flatSize;
611 } else {
612 // If this is an 1:1 conversion, no need for casting.
613 packedResultVals.emplace_back();
614 packedResultVals.back().push_back(newCall.getResult(retOffset));
615 retOffset++;
616 }
617 sparseFlat.clear();
618 }
619
620 assert(packedResultVals.size() == op.getNumResults());
621 rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
622 return success();
623 }
624};
625
626/// Sparse codegen rule for level accesses.
627class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
628public:
629 using OpConversionPattern::OpConversionPattern;
630 LogicalResult
631 matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
632 ConversionPatternRewriter &rewriter) const override {
633 std::optional<int64_t> lvl = op.getConstantLvlIndex();
634 RankedTensorType srcType = op.getSource().getType();
635 if (!lvl || !getSparseTensorEncoding(srcType))
636 return failure();
637
638 auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
639 auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
640
641 rewriter.replaceOp(op, sz);
642 return success();
643 }
644};
645
646// TODO: use a new SortCOO operation here instead of reusing convert op.
647struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
648 using OpConversionPattern::OpConversionPattern;
649 LogicalResult
650 matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
651 ConversionPatternRewriter &rewriter) const override {
652 Location loc = op.getLoc();
653 MLIRContext *ctx = op.getContext();
654
655 SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
656 SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
657
658 // Should have been verified.
659 assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
660 dstStt.isCOOType() && srcStt.isCOOType());
661 assert(dstStt.hasSameDimToLvl(srcStt));
662
663 // We don't need a mutable descriptor here as we perform sorting in-place.
664 auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
665 op.getInputCoo().getType());
666 auto nnz = desc.getValMemSize(rewriter, op.getLoc());
667 auto crd = desc.getAOSMemRef();
668 auto val = desc.getValMemRef();
669
670 // Otherwise we need another data shuffle and a non-identity map.
671 assert(dstStt.hasSameDimToLvl(srcStt));
672 (void)dstStt; // to silence warning when assertion is disabled
673
674 auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
675
676 SortOp::create(rewriter, loc, nnz, crd, ValueRange{val}, id,
677 rewriter.getIndexAttr(0), op.getAlgorithm());
678
679    // Since we do in-place sorting, the destination tensor will have the same set
680 // of memrefs as the source tensor.
681 rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
682 return success();
683 }
684};
685
686template <typename Op, StorageSpecifierKind kind>
687class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
688public:
689 using OpConversionPattern<Op>::OpConversionPattern;
690 using typename OpConversionPattern<Op>::OneToNOpAdaptor;
691
692 LogicalResult
693 matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
694 ConversionPatternRewriter &rewriter) const override {
695    // Simply lowers to a specifier.get <field> operation.
696 auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
697 op.getSlice().getType());
698 auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
699 op.getDim().getZExtValue());
700
701 rewriter.replaceOp(op, v);
702 return success();
703 }
704};
705
706/// Sparse codegen rule for trivial tensor casts.
707class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
708public:
709 using OpConversionPattern::OpConversionPattern;
710 LogicalResult
711 matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
712 ConversionPatternRewriter &rewriter) const override {
713 // Only rewrite identically annotated source/dest.
714 auto encDst = getSparseTensorEncoding(op.getType());
715 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
716 if (!encDst || encDst != encSrc)
717 return failure();
718 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
719 return success();
720 }
721};
722
723class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
724public:
725 using OpConversionPattern::OpConversionPattern;
726 LogicalResult
727 matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
728 ConversionPatternRewriter &rewriter) const override {
729 // Simply fold the operation.
730 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
731 return success();
732 }
733};
734
735/// Sparse codegen rule for the alloc operator.
736class SparseTensorAllocConverter
737 : public OpConversionPattern<bufferization::AllocTensorOp> {
738public:
739 using OpConversionPattern::OpConversionPattern;
740 SparseTensorAllocConverter(const TypeConverter &typeConverter,
741 MLIRContext *context, bool enableInit)
742 : OpConversionPattern(typeConverter, context),
743 enableBufferInitialization(enableInit) {}
744
745 LogicalResult
746 matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter) const override {
748 const auto resType = getSparseTensorType(op);
749 if (!resType.hasEncoding())
750 return failure();
751
752 Location loc = op.getLoc();
753 // Deal with copy.
754 if (op.getCopy()) {
755      auto desc = getDescriptorFromTensorTuple(
756          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
757 SmallVector<Value> fields;
758 fields.reserve(desc.getNumFields());
759 // Memcpy on memref fields.
760 for (auto field : desc.getMemRefFields()) {
761 auto memrefTp = cast<MemRefType>(field.getType());
762 auto size = memref::DimOp::create(rewriter, loc, field, 0);
763 auto copied =
764 memref::AllocOp::create(rewriter, loc, memrefTp, ValueRange{size});
765 memref::CopyOp::create(rewriter, loc, field, copied);
766 fields.push_back(copied);
767 }
768 // Reuses specifier.
769 fields.push_back(desc.getSpecifier());
770 assert(fields.size() == desc.getNumFields());
771 rewriter.replaceOpWithMultiple(op, {fields});
772 return success();
773 }
774
775 if (!resType.isIdentity()) {
776 return rewriter.notifyMatchFailure(
777 op, "try run --sparse-reinterpret-map before codegen");
778 }
779    // Level size equals dimension size since the lvl2dim map is an identity map.
780 SmallVector<Value> lvlSizesValues;
781 createDimSizes(rewriter, loc, resType,
782 flattenValues(adaptor.getDynamicSizes()),
783 /*dimSizesValues=*/lvlSizesValues);
784
785 // Construct allocation for each field.
786 Value sizeHint = op.getSizeHint();
787 SmallVector<Value> fields;
788 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
789 sizeHint, lvlSizesValues, fields);
790
791 // Replace operation with resulting memrefs.
792 rewriter.replaceOpWithMultiple(op, {fields});
793 return success();
794 }
795
796private:
797 bool enableBufferInitialization;
798};
799
800/// Sparse codegen rule for the empty tensor operator.
801class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
802public:
803 using OpConversionPattern::OpConversionPattern;
804 SparseTensorEmptyConverter(const TypeConverter &typeConverter,
805 MLIRContext *context, bool enableInit)
806 : OpConversionPattern(typeConverter, context),
807 enableBufferInitialization(enableInit) {}
808
809 LogicalResult
810 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
811 ConversionPatternRewriter &rewriter) const override {
812 const auto resType = getSparseTensorType(op);
813 if (!resType.hasEncoding())
814 return failure();
815
816 if (!resType.isIdentity()) {
817 return rewriter.notifyMatchFailure(
818 op, "try run --sparse-reinterpret-map before codegen");
819 }
820
821 Location loc = op.getLoc();
822    // Level size equals dimension size since the lvl2dim map is an identity map.
823 SmallVector<Value> lvlSizesValues;
824 createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
825 /*dimSizesValues=*/lvlSizesValues);
826 // Construct allocation for each field.
827 Value sizeHint; // none
828 SmallVector<Value> fields;
829 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
830 sizeHint, lvlSizesValues, fields);
831
832 // Replace operation with resulting memrefs.
833 rewriter.replaceOpWithMultiple(op, {fields});
834 return success();
835 }
836
837private:
838 bool enableBufferInitialization;
839};
840
841/// Sparse codegen rule for the dealloc operator.
842class SparseTensorDeallocConverter
843 : public OpConversionPattern<bufferization::DeallocTensorOp> {
844public:
845 using OpConversionPattern::OpConversionPattern;
846 SparseTensorDeallocConverter(const TypeConverter &typeConverter,
847 MLIRContext *context, bool createDeallocs)
848 : OpConversionPattern(typeConverter, context),
849 createDeallocs(createDeallocs) {}
850
851 LogicalResult
852 matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
853 ConversionPatternRewriter &rewriter) const override {
854 auto enc = getSparseTensorEncoding(op.getTensor().getType());
855 if (!enc)
856 return failure();
857
858 // If user requests not to deallocate sparse tensors, simply erase the
859 // operation.
860 if (createDeallocs) {
861 // Replace the sparse tensor deallocation with field deallocations.
862 Location loc = op.getLoc();
863      auto desc = getDescriptorFromTensorTuple(
864          adaptor.getTensor(),
865 cast<RankedTensorType>(op.getTensor().getType()));
866 for (auto input : desc.getMemRefFields())
867 // Deallocate every buffer used to store the sparse tensor handler.
868 memref::DeallocOp::create(rewriter, loc, input);
869 }
870 rewriter.eraseOp(op);
871 return success();
872 }
873
874private:
875 const bool createDeallocs;
876};
877
878/// Sparse codegen rule for tensor rematerialization.
879class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
880public:
881 using OpConversionPattern::OpConversionPattern;
882 LogicalResult
883 matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
884 ConversionPatternRewriter &rewriter) const override {
885 // Prepare descriptor.
886 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
887 op.getTensor().getType());
888 // Generate optional insertion finalization code.
889 if (op.getHasInserts())
890 genEndInsert(rewriter, op.getLoc(), desc);
891 // Replace operation with resulting memrefs.
892 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
893 return success();
894 }
895};
896
897/// Sparse codegen rule for the expand op.
898class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
899public:
900 using OpConversionPattern::OpConversionPattern;
901 LogicalResult
902 matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
903 ConversionPatternRewriter &rewriter) const override {
904 if (!getSparseTensorEncoding(op.getTensor().getType()))
905 return failure();
906 Location loc = op->getLoc();
907 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
908 op.getTensor().getType());
909 const auto srcType = getSparseTensorType(op.getTensor());
910 Type eltType = srcType.getElementType();
911 Type boolType = rewriter.getIntegerType(1);
912 Type idxType = rewriter.getIndexType();
913 // All initialization should be done on entry of the loop nest.
914 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
915
916 // Determine the size for access expansion (always the innermost stored
917 // level size).
918 const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
919 // Generate a memref for `sz` elements of type `t`.
920 const auto genAlloc = [&](Type t) {
921 const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
922 return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
923 };
924 // Allocate temporary buffers for values/filled-switch and added.
925 // We do not use stack buffers for this, since the expanded size may
926 // be rather large (as it envelops a single expanded dense dimension).
927 Value values = genAlloc(eltType);
928 Value filled = genAlloc(boolType);
929 Value added = genAlloc(idxType);
930 Value zero = constantZero(rewriter, loc, idxType);
931 // Reset the values/filled-switch to all-zero/false. Note that this
932 // introduces an O(N) operation into the computation, but this reset
933 // operation is amortized over the innermost loops for the access
934 // pattern expansion. As noted in the operation doc, we would like
935 // to amortize this setup cost even between kernels.
936 linalg::FillOp::create(rewriter, loc,
937 ValueRange{constantZero(rewriter, loc, eltType)},
938 ValueRange{values});
939 linalg::FillOp::create(rewriter, loc,
940 ValueRange{constantZero(rewriter, loc, boolType)},
941 ValueRange{filled});
942 // Replace expansion op with these buffers and initial coordinate.
943 assert(op.getNumResults() == 4);
944 rewriter.replaceOp(op, {values, filled, added, zero});
945 return success();
946 }
947};
948
949/// Sparse codegen rule for the compress operator.
950class SparseCompressConverter : public OpConversionPattern<CompressOp> {
951public:
952 using OpConversionPattern::OpConversionPattern;
953 LogicalResult
954 matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
955 ConversionPatternRewriter &rewriter) const override {
956 Location loc = op->getLoc();
957 SmallVector<Value> fields;
958 auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
959 op.getTensor().getType());
960 Value values = llvm::getSingleElement(adaptor.getValues());
961 Value filled = llvm::getSingleElement(adaptor.getFilled());
962 Value added = llvm::getSingleElement(adaptor.getAdded());
963 Value count = llvm::getSingleElement(adaptor.getCount());
964 const SparseTensorType dstType(desc.getRankedTensorType());
965 Type eltType = dstType.getElementType();
966
967 // If the innermost level is ordered, we need to sort the coordinates
968 // in the "added" array prior to applying the compression.
969 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
970 SortOp::create(rewriter, loc, count, added, ValueRange{},
971 rewriter.getMultiDimIdentityMap(1),
972 rewriter.getIndexAttr(0),
973 SparseTensorSortKind::HybridQuickSort);
974 // While performing the insertions, we also need to reset the elements
975 // of the values/filled-switch by only iterating over the set elements,
976 // to ensure that the runtime complexity remains proportional to the
977 // sparsity of the expanded access pattern.
978 //
979 // Generate
980 // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
981 // crd = added[i];
982 // value = values[crd];
983 // insert({lvlCoords, crd}, value);
984 // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
985 // values[crd] = 0;
986 // filled[crd] = false;
987 // yield new_memrefs
988 // }
989 scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
990 Value i = loop.getInductionVar();
991
992 Value crd = genLoad(rewriter, loc, added, i);
993 Value value = genLoad(rewriter, loc, values, crd);
994 SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
995 SmallVector<Type> flatSpTensorTps = llvm::to_vector(
996 llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
997 SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
998 params.append(flatLvlCoords.begin(), flatLvlCoords.end());
999 params.push_back(crd);
1000 params.push_back(value);
1001 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1002 params, /*genCall=*/true);
1003 SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1004 genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
1005 genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1006 scf::YieldOp::create(rewriter, loc, insertRet);
1007
1008 rewriter.setInsertionPointAfter(loop);
1009 // Deallocate the buffers on exit of the full loop nest.
1010 Operation *parent = getTop(op);
1011 rewriter.setInsertionPointAfter(parent);
1012 memref::DeallocOp::create(rewriter, loc, values);
1013 memref::DeallocOp::create(rewriter, loc, filled);
1014 memref::DeallocOp::create(rewriter, loc, added);
1015 // Replace operation with resulting memrefs.
1016 rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1017 return success();
1018 }
1019};
1020
1021/// Sparse codegen rule for the insert operator.
1022class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1023public:
1024 using OpConversionPattern::OpConversionPattern;
1025 LogicalResult
1026 matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1027 ConversionPatternRewriter &rewriter) const override {
1028 auto stt = getSparseTensorType(op.getDest());
1029 if (!stt.hasEncoding())
1030 return failure();
1031 assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1032
1033 Location loc = op.getLoc();
1034 auto desc =
1035 getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1036 TypeRange flatSpTensorTps = desc.getFields().getTypes();
1037 SmallVector<Value> params = llvm::to_vector(desc.getFields());
1038 SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
1039 params.append(flatIndices.begin(), flatIndices.end());
1040 params.push_back(llvm::getSingleElement(adaptor.getScalar()));
1041 SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1042 params, /*genCall=*/true);
1043 SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1044 // Replace operation with resulting memrefs.
1045 rewriter.replaceOpWithMultiple(op, {ret});
1046 return success();
1047 }
1048};
1049
1050/// Sparse codegen rule for position accesses.
1051class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1052public:
1053 using OpAdaptor = ToPositionsOp::Adaptor;
1054 using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1055 LogicalResult
1056 matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1057 ConversionPatternRewriter &rewriter) const override {
1058 // Replace the requested position access with corresponding field.
1059 // The view is restricted to the actual size to ensure clients
1060 // of this operation truly observe size, not capacity!
1061 Location loc = op.getLoc();
1062 Level lvl = op.getLevel();
1063 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1064 op.getTensor().getType());
1065 auto mem = desc.getPosMemRef(lvl);
1066 auto size = desc.getPosMemSize(rewriter, loc, lvl);
1067 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1068 return success();
1069 }
1070};
1071
1072/// Sparse codegen rule for accessing the coordinates arrays.
1073class SparseToCoordinatesConverter
1074 : public OpConversionPattern<ToCoordinatesOp> {
1075public:
1076 using OpAdaptor = ToCoordinatesOp::Adaptor;
1077 using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1078 LogicalResult
1079 matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1080 ConversionPatternRewriter &rewriter) const override {
1081 // Replace the requested coordinates access with corresponding field.
1082 // The view is restricted to the actual size to ensure clients
1083 // of this operation truly observe size, not capacity!
1084 Location loc = op.getLoc();
1085 Level lvl = op.getLevel();
1086 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1087 op.getTensor().getType());
1088 auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1089 if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1090 auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1091 mem = genSliceToSize(rewriter, loc, mem, size);
1092 }
1093 rewriter.replaceOp(op, mem);
1094 return success();
1095 }
1096};
1097
1098/// Sparse codegen rule for accessing the linear coordinates buffer.
1099class SparseToCoordinatesBufferConverter
1100 : public OpConversionPattern<ToCoordinatesBufferOp> {
1101public:
1102 using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
1103 using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1104 LogicalResult
1105 matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1106 ConversionPatternRewriter &rewriter) const override {
1107 // Replace the requested coordinates access with corresponding field.
1108 // The view is restricted to the actual size to ensure clients
1109 // of this operation truly observe size, not capacity!
1110 Location loc = op.getLoc();
1111 Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1112 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1113 op.getTensor().getType());
1114 auto mem = desc.getAOSMemRef();
1115 auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1116 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1117 return success();
1118 }
1119};
1120
1121/// Sparse codegen rule for value accesses.
1122class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1123public:
1124 using OpAdaptor = ToValuesOp::Adaptor;
1125 using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1126 LogicalResult
1127 matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1128 ConversionPatternRewriter &rewriter) const override {
1129 // Replace the requested values access with corresponding field.
1130 // The view is restricted to the actual size to ensure clients
1131 // of this operation truly observe size, not capacity!
1132 Location loc = op.getLoc();
1133 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1134 op.getTensor().getType());
1135 auto mem = desc.getValMemRef();
1136 auto size = desc.getValMemSize(rewriter, loc);
1137 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1138 return success();
1139 }
1140};
1141
1142/// Sparse codegen rule for the convert operator.
1143class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1144public:
1145 using OpConversionPattern::OpConversionPattern;
1146 LogicalResult
1147 matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter) const override {
1149 SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1150 SparseTensorEncodingAttr encSrc =
1151 getSparseTensorEncoding(op.getSource().getType());
1152 // The output tensor can not be a slice and those cases should have been
1153 // rejected by ConvertOp::verify() already.
1154    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1155 // Different encoding (except for different bitwidth) should be handled by
1156 // rewriting.
1157 // We need further rewrites if the input tensor is a slice too.
1158 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1159 encSrc.isSlice()) {
1160 return failure();
1161 }
1162
1163 Type retElemTp = op.getResult().getType().getElementType();
1164 Type srcElemTp = op.getSource().getType().getElementType();
1165 // Fold the trivial cases.
1166 if (retElemTp == srcElemTp && encDst == encSrc) {
1167 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1168 return success();
1169 }
1170 //
1171 // Do element-wise type conversion without using InsertOp.
1172 //
1173 // for each memref in srcTensor:
1174 // dst = memref.alloc
1175 // if srcMemRefType != dstMemRefType:
1176 // for every dst[i] = cast(src[i])
1177 // else:
1178 // dst = memref.copy(src)
1179 Location loc = op.getLoc();
1180 auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1181 op.getSource().getType());
1182 SmallVector<Value> fields;
1183    foreachFieldAndTypeInSparseTensor(
1184        SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1185 [&rewriter, &fields, srcDesc,
1186 loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1187 LevelType /*lt*/) -> bool {
1188 // Simply reuses the storage specifier as it is an SSA value.
1189 if (fKind == SparseTensorFieldKind::StorageSpec) {
1190 fields.push_back(srcDesc.getSpecifier());
1191 } else {
1192 // Allocates new memrefs
1193 Value srcMem = srcDesc.getMemRefField(fIdx);
1194 // TODO: We can instead use the actual memSize in specifier, that
1195 // would require a subViewOp to avoid overflow when copying
1196 // values.
1197 Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1198 auto dstMem = memref::AllocOp::create(rewriter, loc,
1199 cast<MemRefType>(fTp), sz);
1200 if (fTp != srcMem.getType()) {
1201 // Converts elements type.
1202 scf::buildLoopNest(
1203 rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1204 constantIndex(rewriter, loc, 1),
1205 [srcMem, &dstMem](OpBuilder &builder, Location loc,
1206 ValueRange ivs) {
1207 Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
1208 Value casted = genCast(builder, loc, v,
1209 dstMem.getType().getElementType());
1210 memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
1211 });
1212 } else {
1213 // TODO: We can even reuse the same memref for the new tensor,
1214 // but that requires a `ref-counting` based memory management
1215 // for shared memrefs between multiple sparse tensors.
1216 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1217 }
1218 fields.push_back(dstMem);
1219 }
1220 return true;
1221 });
1222
1223 rewriter.replaceOpWithMultiple(op, {fields});
1224 return success();
1225 }
1226};
1227
1228class SparseExtractSliceConverter
1229 : public OpConversionPattern<tensor::ExtractSliceOp> {
1230public:
1231 using OpConversionPattern::OpConversionPattern;
1232 LogicalResult
1233 matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1234 ConversionPatternRewriter &rewriter) const override {
1235 Location loc = op.getLoc();
1236 MLIRContext *ctx = op.getContext();
1237 auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1238 auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1239 // TODO: We should check these in ExtractSliceOp::verify.
1240 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1241 return failure();
1242 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1243
1244 SmallVector<Value> fields;
1245 auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1246 op.getSource().getType());
1247
1248 auto newSpec = StorageSpecifierInitOp::create(
1249 rewriter, loc, StorageSpecifierType::get(ctx, dstEnc),
1250 desc.getSpecifier());
1251 desc.setSpecifier(newSpec);
1252
1253 // Fills in slice information.
1254 for (auto [idx, offset, size, stride] : llvm::enumerate(
1255 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1256 Dimension dim = idx;
1257
1258 Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1259 Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1260 Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1261      // TODO: We could probably set only the dynamic values here. But that
1262      // would require us to fill the hole when casting a static slice to a
1263      // dynamic slice.
1264 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1265 dim, offsetV);
1266
1267 // FIXME: we need to distinguish level sizes and dimension size for slices
1268 // here. Maybe we should store slice level sizes in a different array
1269 // instead of reusing it.
1270 assert(srcEnc.isIdentity());
1271 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1272 sizeV);
1273 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1274 dim, strideV);
1275 }
1276
1277    // NOTE: we cannot generate tuples directly from the descriptor here, as
1278    // the descriptor holds the original type, yet we want the slice type
1279    // here (they share every memref but with an updated specifier).
1280 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1281 return success();
1282 }
1283};
1284
1285/// Sparse codegen rule for number of entries operator.
1286class SparseNumberOfEntriesConverter
1287 : public OpConversionPattern<NumberOfEntriesOp> {
1288public:
1289 using OpConversionPattern::OpConversionPattern;
1290 LogicalResult
1291 matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1292 ConversionPatternRewriter &rewriter) const override {
1293 // Query memSizes for the actually stored values.
1294 // FIXME: the nse value computed in this way might be wrong when there is
1295 // any "loose_compressed" level.
1296 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1297 op.getTensor().getType());
1298 rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1299 return success();
1300 }
1301};
1302
1303struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1304 using OpConversionPattern::OpConversionPattern;
1305 LogicalResult
1306 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1307 ConversionPatternRewriter &rewriter) const override {
1308 Location loc = op.getLoc();
1309 const auto stt = getSparseTensorType(op.getResult());
1310
1311 SmallVector<Value> fields;
1312
1313    foreachFieldAndTypeInSparseTensor(
1314        stt,
1315 [&rewriter, &fields, &op, &stt,
1316 loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1317 Level /*lvl*/, LevelType lt) -> bool {
1318 assert(fields.size() == fIdx);
1319 if (fKind == SparseTensorFieldKind::StorageSpec) {
1320 fields.push_back(
1321 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1322 } else {
1323 // Else simply takes the inputs.
1324 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1325 ? op.getValues()
1326 : op.getLevels()[fIdx];
1327 // TODO: handle batch.
1328 TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1329 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1330 // Flattens the buffer to batchLvlRank.
1331 auto reassoc = getReassociationForFlattening(
1332 mem.getType(), stt.getBatchLvlRank());
1333 mem = memref::CastOp::create(
1334 rewriter, loc, fType,
1335 memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
1336 } else {
1337 mem = memref::CastOp::create(rewriter, loc, fType, mem);
1338 }
1339 fields.push_back(mem);
1340 }
1341 return true;
1342 });
1343
1344 MutSparseTensorDescriptor desc(stt, fields);
1345 Value c0 = constantIndex(rewriter, loc, 0);
1346 Value c1 = constantIndex(rewriter, loc, 1);
1347 Value c2 = constantIndex(rewriter, loc, 2);
1348 Value posBack = c0; // index to the last value in the position array
1349 Value memSize = c1; // memory size for current array
1350
1351 Level trailCOOStart = stt.getAoSCOOStart();
1352 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1353 // Sets up SparseTensorSpecifier.
1354 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1355 assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
1356
1357 // Sets up the level size.
1358 auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1359 desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1360 // We use a single AOS array to store the trailing COO, so there is only
1361 // one memory size to set for the entire COO section.
1362 if (lvl > trailCOOStart)
1363 continue;
1364
1365 // Sets up the memory size by reading the last value in position array.
1366 LevelType lt = stt.getLvlType(lvl);
1367 // Simply forwards the position index when this is a dense level.
1368 if (lt.isa<LevelFormat::Dense>()) {
1369 memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
1370 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1371 continue;
1372 }
1373 if (lt.isa<LevelFormat::Batch>()) {
1374        // Skips batch levels as they are not linearized.
1375 // FIXME: this assumes that every batch has the same number of nse, need
1376 // to be generalized to handle varied-size batches.
1377 continue;
1378 }
1379
1380 if (isWithPosLT(lt)) {
1381 assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1382 if (isLooseCompressedLT(lt)) {
1383 memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
1384 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1385 } else {
1386 assert(isCompressedLT(lt));
1387 posBack = memSize;
1388 memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
1389 }
1390 desc.setPosMemSize(rewriter, loc, lvl, memSize);
1391 // The last value in the position array is the memory size for the next level.
1392 // FIXME: this assumes that every batch has the same number of nse; this
1393 // needs to be generalized to handle varied-size batches.
1394 SmallVector<Value> batched(stt.getBatchLvlRank(),
1395 constantIndex(rewriter, loc, 0));
1396 batched.push_back(posBack);
1397 memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1398 posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
1399 }
1400 assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1401 // FIXME: This seems unnecessarily complex; can we simplify it?
1402 if (lvl == trailCOOStart) {
1403 Value cooSz = arith::MulIOp::create(
1404 rewriter, loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1405 desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1406 } else {
1407 desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1408 }
1409 }
1410 desc.setValMemSize(rewriter, loc, memSize);
1411
1412 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1413 return success();
1414 }
1415};
1416
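// Converts `sparse_tensor.disassemble` by copying each storage buffer into the
// user-supplied output buffer and returning the number of entries actually
// used. Illustrative sketch (pseudo-IR only) for one non-specifier field:
//
//   %sz  = <pos|crd|val> memory size taken from the storage specifier
//   %dst = buffer of the matching out_levels / out_values operand
//   memref.copy subview(%src)[0, %sz], subview(%dst)[0, %sz]
//   yield to_tensor(%dst) and cast(%sz) among the op results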
1417struct SparseDisassembleOpConverter
1418 : public OpConversionPattern<DisassembleOp> {
1419 using OpConversionPattern::OpConversionPattern;
1420 SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1421 MLIRContext *context)
1422 : OpConversionPattern(typeConverter, context) {}
1423
1424 LogicalResult
1425 matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1426 ConversionPatternRewriter &rewriter) const override {
1427 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1428 op.getTensor().getType());
1429 Location loc = op.getLoc();
1430 SmallVector<Value> retMem;
1431 SmallVector<Value> retLen;
1432 desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1433 &retLen](FieldIndex fid,
1434 SparseTensorFieldKind fKind,
1435 Level lvl, LevelType lt) -> bool {
1436 if (fKind == SparseTensorFieldKind::StorageSpec)
1437 return true;
1438 SparseTensorType stt(desc.getRankedTensorType());
1439 Value sz, src;
1440 TypedValue<BaseMemRefType> dst;
1441 if (fKind == SparseTensorFieldKind::ValMemRef) {
1442 sz = desc.getValMemSize(rewriter, loc);
1443 src = desc.getValMemRef();
1444 dst = genToMemref(rewriter, loc, op.getOutValues());
1445
1446 retMem.push_back(dst);
1447 Type valLenTp = op.getValLen().getType();
1448 retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1449 } else {
1450 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1451 fKind == SparseTensorFieldKind::CrdMemRef);
1452
1453 sz = fKind == SparseTensorFieldKind::PosMemRef
1454 ? desc.getPosMemSize(rewriter, loc, lvl)
1455 : desc.getCrdMemSize(rewriter, loc, lvl);
1456 src = desc.getMemRefField(fid);
1457 dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1458 retMem.push_back(dst);
1459 // Retrieves the corresponding level length type.
1460 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1461 retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1462 }
1463 Value flatOut = dst;
1464 if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1465 auto reassoc =
1466 getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1467 flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
1468 }
1469 Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1470 Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1471 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1472 return true;
1473 });
1474
1475 // Converts MemRefs back to Tensors.
1476 SmallVector<Value> retValues = llvm::to_vector(
1477 llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1478 return bufferization::ToTensorOp::create(
1479 rewriter, loc, memref::getTensorTypeFromMemRefType(v.getType()),
1480 v);
1481 }));
1482 // Appends the actual length (used size) of each returned buffer.
1483 retValues.append(retLen.begin(), retLen.end());
1484 rewriter.replaceOp(op, retValues);
1485 return success();
1486 }
1487};
1488
1489struct SparseNewConverter : public OpConversionPattern<NewOp> {
1490 using OpConversionPattern::OpConversionPattern;
1491 LogicalResult
1492 matchAndRewrite(NewOp op, OpAdaptor adaptor,
1493 ConversionPatternRewriter &rewriter) const override {
1494 Location loc = op.getLoc();
1495 const auto dstTp = getSparseTensorType(op.getResult());
1496 // Creating COO with NewOp is handled by direct IR codegen. All other cases
1497 // are handled by rewriting.
1498 if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1499 return failure();
1500
1501 // Implement as follows:
1502 // %reader = @createCheckedSparseTensorReader(%filename)
1503 // %nse = @getSparseTensorNSE(%reader)
1504 // %coo = bufferization.alloc_tensor an ordered COO with
1505 // dst dim ordering, size_hint = %nse
1506 // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1507 // %values = sparse_tensor.values(%coo)
1508 // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1509 // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1510 // update storage specifier
1511 // @delSparseTensorReader(%reader)
1512 SmallVector<Value> dimSizesValues;
1513 Value dimSizesBuffer;
1514 Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1515 dimSizesValues, dimSizesBuffer);
1516
1517 // Get the number of stored entries.
1518 const Type indexTp = rewriter.getIndexType();
1519 Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1520 {indexTp}, {reader}, EmitCInterface::Off)
1521 .getResult(0);
1522
1523 // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1524 SmallVector<Value> lvlSizesValues;
1525 Value dim2lvlBuffer;
1526 Value lvl2dimBuffer;
1527 genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1528 lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1529
1530 // Construct allocation for each field.
1531 Value sizeHint = nse;
1532 SmallVector<Value> fields;
1533 createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1534 lvlSizesValues, fields);
1535
1536 // Read the COO tensor data.
1537 MutSparseTensorDescriptor desc(dstTp, fields);
1538 Value xs = desc.getAOSMemRef();
1539 Value ys = desc.getValMemRef();
1540 const Type boolTp = rewriter.getIntegerType(1);
1541 const Type elemTp = dstTp.getElementType();
1542 const Type crdTp = dstTp.getCrdType();
1543 SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1544 overheadTypeFunctionSuffix(crdTp),
1545 primaryTypeFunctionSuffix(elemTp)};
1546 Value isSorted =
1547 createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1548 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1549 EmitCInterface::On)
1550 .getResult(0);
1551
1552 // If the destination tensor is an ordered COO, sort the COO data whenever
1553 // the input elements were not already sorted.
1554 const Level lvlRank = dstTp.getLvlRank();
1555 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1556 Value kFalse = constantI1(rewriter, loc, false);
1557 Value notSorted = arith::CmpIOp::create(
1558 rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1559 scf::IfOp ifOp =
1560 scf::IfOp::create(rewriter, loc, notSorted, /*else*/ false);
1561 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1562 auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1563 SortOp::create(rewriter, loc, nse, xs, ValueRange{ys}, xPerm,
1564 rewriter.getIndexAttr(0),
1565 SparseTensorSortKind::HybridQuickSort);
1566 rewriter.setInsertionPointAfter(ifOp);
1567 }
1568
1569 // Set PosMemRef0[1] = nse (the level-0 positions buffer of a COO is [0, nse]).
1570 const Value c1 = constantIndex(rewriter, loc, 1);
1571 const Value posMemref0 = desc.getPosMemRef(0);
1572 const Type posTp = dstTp.getPosType();
1573 const Value posNse = genCast(rewriter, loc, nse, posTp);
1574 memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
1575
1576 // Update storage specifier.
1577 Value coordinatesSize = arith::MulIOp::create(
1578 rewriter, loc, nse, constantIndex(rewriter, loc, lvlRank));
1579 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1580 coordinatesSize);
1581 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1582 std::nullopt, nse);
1583
1584 // Release the sparse tensor reader.
1585 createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1586 EmitCInterface::Off);
1587
1588 // Replace operation with resulting memrefs.
1589 rewriter.replaceOpWithMultiple(op, {fields});
1590 return success();
1591 }
1592};
1593
1594struct SparseHasRuntimeLibraryConverter
1595 : public OpConversionPattern<HasRuntimeLibraryOp> {
1596 using OpConversionPattern::OpConversionPattern;
1597 LogicalResult
1598 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1599 ConversionPatternRewriter &rewriter) const override {
1600 auto i1Type = rewriter.getI1Type();
1601 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1602 op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1603 return success();
1604 }
1605};
1606
1607} // namespace
1608
1609//===----------------------------------------------------------------------===//
1610// Public method for populating conversion rules.
1611//===----------------------------------------------------------------------===//
1612
1613/// Populates the given patterns list with conversion rules required to
1614/// lower sparse tensor types and primitives to compiler-visible buffers.
1615void mlir::populateSparseTensorCodegenPatterns(
1616 const TypeConverter &typeConverter, RewritePatternSet &patterns,
1617 bool createSparseDeallocs, bool enableBufferInitialization) {
1618 patterns.add<
1619 SparseAssembleOpConverter, SparseDisassembleOpConverter,
1620 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1621 SparseCastConverter, SparseExtractSliceConverter,
1622 SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1623 SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1624 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1625 StorageSpecifierKind::DimOffset>,
1626 SparseSliceGetterOpConverter<ToSliceStrideOp,
1627 StorageSpecifierKind::DimStride>,
1628 SparseToPositionsConverter, SparseToCoordinatesConverter,
1629 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1630 SparseConvertConverter, SparseNewConverter,
1631 SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1632 typeConverter, patterns.getContext());
1633 patterns.add<SparseTensorDeallocConverter>(
1634 typeConverter, patterns.getContext(), createSparseDeallocs);
1635 patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1636 typeConverter, patterns.getContext(), enableBufferInitialization);
1637}
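// Minimal usage sketch (illustrative only; `runSparseCodegen` and the chosen
// legality configuration are hypothetical, and a real driver also constrains
// func/bufferization ops via the type converter): combine these patterns with
// the sparse tensor codegen type converter and apply a partial conversion.
//
//   void runSparseCodegen(Operation *root, MLIRContext *ctx) {
//     SparseTensorTypeToBufferConverter converter;
//     ConversionTarget target(*ctx);
//     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
//                            scf::SCFDialect>();
//     RewritePatternSet patterns(ctx);
//     populateSparseTensorCodegenPatterns(converter, patterns,
//                                         /*createSparseDeallocs=*/true,
//                                         /*enableBufferInitialization=*/false);
//     (void)applyPartialConversion(root, target, std::move(patterns));
//   }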