SparseTensorCodegen.cpp
1//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// A pass that converts sparse tensor types and primitives to actual compiler
10// visible buffers and actual compiler IR that implements these primitives on
11// the selected sparse tensor storage schemes. This pass provides an alternative
12// to the SparseTensorConversion pass, eliminating the dependence on a runtime
13// support library (other than for file I/O), and providing many more
14// opportunities for subsequent compiler optimization of the generated code.
15//
16//===----------------------------------------------------------------------===//
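//
// Illustration (not part of the original file): with this pass, a 2-D CSR
// tensor such as tensor<?x?xf64, #CSR> is no longer an opaque runtime object
// but a small set of compiler-visible buffers, conceptually:
//   positions[1]   : memref<?xindex>   // row pointers
//   coordinates[1] : memref<?xindex>   // column indices
//   values         : memref<?xf64>     // stored nonzero entries
//   specifier      : !sparse_tensor.storage_specifier<...>  // sizes per field
// The exact field list and integer widths depend on the chosen encoding.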
17
18#include "Utils/CodegenUtils.h"
19#include "Utils/SparseTensorDescriptor.h"
20
21#include "mlir/Dialect/Arith/Utils/Utils.h"
22#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Linalg/Utils/Utils.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30#include "mlir/Dialect/Tensor/IR/Tensor.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "llvm/ADT/SmallVectorExtras.h"
33
34#include <optional>
35
36using namespace mlir;
37using namespace mlir::sparse_tensor;
38
39//===----------------------------------------------------------------------===//
40// Helper methods.
41//===----------------------------------------------------------------------===//
42
43/// Flatten the given value ranges into a single vector of values.
44 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
45   SmallVector<Value> result;
46 for (const auto &vals : values)
47 llvm::append_range(result, vals);
48 return result;
49}
50
51/// Generates a load with proper `index` typing.
52static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
53 idx = genCast(builder, loc, idx, builder.getIndexType());
54 return memref::LoadOp::create(builder, loc, mem, idx);
55}
56
57/// Generates a store with proper `index` typing and proper value.
58static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
59 Value idx) {
60 idx = genCast(builder, loc, idx, builder.getIndexType());
61 val = genCast(builder, loc, val,
62 cast<ShapedType>(mem.getType()).getElementType());
63 memref::StoreOp::create(builder, loc, val, mem, idx);
64}
65
66/// Creates a straightforward counting for-loop.
67static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
68                             MutableArrayRef<Value> fields,
69 Value lower = Value()) {
70 Type indexType = builder.getIndexType();
71 if (!lower)
72 lower = constantZero(builder, loc, indexType);
73 Value one = constantOne(builder, loc, indexType);
74 scf::ForOp forOp =
75 scf::ForOp::create(builder, loc, lower, upper, one, fields);
76 for (unsigned i = 0, e = fields.size(); i < e; i++)
77 fields[i] = forOp.getRegionIterArg(i);
78 builder.setInsertionPointToStart(forOp.getBody());
79 return forOp;
80}
81
82/// Creates a push back operation.
83static void createPushback(OpBuilder &builder, Location loc,
84                            MutSparseTensorDescriptor desc,
85 SparseTensorFieldKind kind, std::optional<Level> lvl,
86 Value value, Value repeat = Value()) {
87 Type etp = desc.getMemRefElementType(kind, lvl);
88 Value field = desc.getMemRefField(kind, lvl);
89 StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
90
91 auto pushBackOp = PushBackOp::create(
92 builder, loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl),
93 field, genCast(builder, loc, value, etp), repeat);
94
95 desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
96 desc.setSpecifierField(builder, loc, specFieldKind, lvl,
97 pushBackOp.getNewSize());
98}
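// Note (added for illustration): the push_back op yields a possibly
// reallocated buffer together with the new element count; this helper writes
// the returned buffer back into the corresponding descriptor field and the
// returned count back into the storage specifier, so callers always observe a
// consistent (buffer, size) pair.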
99
100/// Generates code that allocates a sparse storage scheme for given rank.
101static void allocSchemeForRank(OpBuilder &builder, Location loc,
102 MutSparseTensorDescriptor desc, Level startLvl) {
103 const SparseTensorType stt(desc.getRankedTensorType());
104 Value linear = constantIndex(builder, loc, 1);
105 const Level lvlRank = stt.getLvlRank();
106 for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
107 const auto lt = stt.getLvlType(lvl);
108 if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
109 // Append linear x positions, initialized to zero. Since each compressed
110 // dimension initially already has a single zero entry, this maintains
111 // the desired "linear + 1" length property at all times. For loose
112 // compression, we multiply linear by two in order to append both the
113 // lo/hi positions.
114 Value posZero = constantZero(builder, loc, stt.getPosType());
115 if (isLooseCompressedLT(lt)) {
116 Value two = constantIndex(builder, loc, 2);
117 linear = arith::MulIOp::create(builder, loc, linear, two);
118 }
119 createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
120 /*value=*/posZero, /*repeat=*/linear);
121 return;
122 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
123 return; // nothing to do
124 }
125 // Keep compounding the size, but nothing needs to be initialized
126 // at this level. We will eventually reach a compressed level or
127 // otherwise the values array for the from-here "all-dense" case.
128 assert(isDenseLT(lt));
129 Value size = desc.getLvlSize(builder, loc, lvl);
130 linear = arith::MulIOp::create(builder, loc, linear, size);
131 }
132 // Reached values array so prepare for an insertion.
133 Value valZero = constantZero(builder, loc, stt.getElementType());
134  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
135 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
136}
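// Illustrative walk-through (assuming a 2-D CSR tensor, i.e. dense then
// compressed): starting at level 0, the loop multiplies `linear` by the dense
// level-0 size, then appends that many zero positions to level 1 and returns.
// Together with the single zero entry already present, this establishes the
// "linear + 1" length property described above.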
137
138/// Creates allocation operation.
139 static Value createAllocation(OpBuilder &builder, Location loc,
140 MemRefType memRefType, Value sz,
141 bool enableInit) {
142 Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
143 Type elemType = memRefType.getElementType();
144 if (enableInit) {
145 Value fillValue = constantZero(builder, loc, elemType);
146 linalg::FillOp::create(builder, loc, fillValue, buffer);
147 }
148 return buffer;
149}
150
151/// Creates the dim sizes array, filling in from dynamic sizes.
152static void createDimSizes(OpBuilder &builder, Location loc,
153 SparseTensorType stt, ValueRange dynSizes,
154 /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
155 const Dimension dimRank = stt.getDimRank();
156 dimSizesValues.clear();
157 dimSizesValues.reserve(dimRank);
158 unsigned i = 0;
159 for (const Size sz : stt.getDimShape())
160 dimSizesValues.push_back(ShapedType::isDynamic(sz)
161 ? dynSizes[i++]
162 : constantIndex(builder, loc, sz));
163}
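// Example (illustration only): for tensor<?x8xf64> with one dynamic size %d0,
// this yields dimSizesValues = [%d0, c8], where c8 is a constant index.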
164
165/// Creates allocation for each field in sparse tensor type. Note that
166/// for all dynamic memrefs in the sparse tensor storage layout, the
167/// memory size is really the capacity of the "vector", while the actual
168/// size resides in the sizes array.
169static void createAllocFields(OpBuilder &builder, Location loc,
170 SparseTensorType stt, bool enableInit,
171 Value sizeHint,
172 SmallVectorImpl<Value> &lvlSizesValues,
173 /*out*/ SmallVectorImpl<Value> &fields) {
174 Level lvlRank = stt.getLvlRank();
175 // Set up some heuristic sizes. We try to set the initial
176 // size based on available information. Otherwise we just
177 // initialize a few elements to start the reallocation chain.
178 // TODO: refine this
179 Value posHeuristic, crdHeuristic, valHeuristic;
180 if (stt.isAllDense()) {
181 valHeuristic = lvlSizesValues[0];
182 for (Level lvl = 1; lvl < lvlRank; lvl++)
183 valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
184 lvlSizesValues[lvl]);
185 } else if (sizeHint) {
186 if (stt.getAoSCOOStart() == 0) {
187 posHeuristic = constantIndex(builder, loc, 2);
188 crdHeuristic = arith::MulIOp::create(
189 builder, loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
190 } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
191 posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
192 constantIndex(builder, loc, 1));
193 crdHeuristic = sizeHint;
194 } else {
195 posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
196 }
197 valHeuristic = sizeHint;
198 } else {
199 posHeuristic = crdHeuristic = valHeuristic =
200 constantIndex(builder, loc, 16);
201 }
202 // Initializes all fields: an initial storage specifier and allocated
203 // positions/coordinates/values memrefs (with heuristic capacity).
204  foreachFieldAndTypeInSparseTensor(
205 stt,
206 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
207 enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
208 Level /*lvl*/, LevelType /*lt*/) -> bool {
209 assert(fields.size() == fIdx);
210 Value field;
211 switch (fKind) {
212        case SparseTensorFieldKind::StorageSpec:
213 field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
214 break;
215        case SparseTensorFieldKind::PosMemRef:
216 field = createAllocation(builder, loc, cast<MemRefType>(fType),
217 posHeuristic, enableInit);
218 break;
219        case SparseTensorFieldKind::CrdMemRef:
220 field = createAllocation(builder, loc, cast<MemRefType>(fType),
221 crdHeuristic, enableInit);
222 break;
223        case SparseTensorFieldKind::ValMemRef:
224 field = createAllocation(builder, loc, cast<MemRefType>(fType),
225 valHeuristic, enableInit);
226 break;
227 }
228 assert(field);
229 fields.push_back(field);
230 // Returns true to continue the iteration.
231 return true;
232 });
233 // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
234 // and gives all position fields an initial zero entry, so that it is
235 // easier to maintain the "linear + 1" length property.
236 MutSparseTensorDescriptor desc(stt, fields);
237 Value posZero = constantZero(builder, loc, stt.getPosType());
238 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
239 desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
240 const auto lt = stt.getLvlType(lvl);
241 if (isCompressedLT(lt) || isLooseCompressedLT(lt))
242 createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
243 /*value=*/posZero);
244 }
245 allocSchemeForRank(builder, loc, desc, /*rank=*/0);
246}
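// Illustrative result (assuming a typical CSR encoding): `fields` ends up as
// { positions[1], coordinates[1], values, specifier }, with the level sizes
// recorded in the specifier and one zero pushed into the positions buffer
// before allocSchemeForRank completes the remainder of the scheme.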
247
248/// Helper method that generates block specific to compressed case:
249///
250/// // given: parentPos = posCursor[lvl-1]
251/// pstart = desc.positions[lvl][parentPos]
252/// pstop = desc.positions[lvl][parentPos+1]
253/// plast = pstop - 1
254/// msz = desc.coordinates[lvl].size()
255/// if (pstart < pstop) {
256/// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
257/// } else { // first insertion
258/// isPresent = false
259/// desc.positions[lvl][parentPos] = msz
260/// }
261/// if (isPresent) { // coordinate is already present
262/// pnext = plast
263/// } else {
264/// desc.coordinates[lvl].push_back(lvlCoords[lvl])
265/// desc.positions[lvl][parentPos+1] = msz+1
266/// pnext = msz
267/// <prepare level lvl+1>
268/// }
269/// posCursor[lvl] = pnext
270 static Value genCompressed(OpBuilder &builder, Location loc,
271                            MutSparseTensorDescriptor desc, ValueRange lvlCoords,
272 Value /*unused*/, Value parentPos, Level lvl) {
273 const SparseTensorType stt(desc.getRankedTensorType());
274 const Level lvlRank = stt.getLvlRank();
275 assert(lvl < lvlRank && "Level is out of bounds");
276 assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
277 "Level-rank mismatch");
278 SmallVector<Type> types;
279 Type indexType = builder.getIndexType();
280 Type boolType = builder.getIntegerType(1);
281 unsigned crdFidx;
282 unsigned crdStride;
283 std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
284 const Value one = constantIndex(builder, loc, 1);
285 const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
286 const Value positionsAtLvl = desc.getPosMemRef(lvl);
287 const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
288 const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
289 const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
290 const Value crdStrideC =
291 crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
292 const Value msz =
293 crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
294 : crdMsz;
295 const Value plast = arith::SubIOp::create(
296 builder, loc, genCast(builder, loc, pstop, indexType), one);
297 // Conditional expression.
298 Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
299 pstart, pstop);
300 types.push_back(boolType);
301 scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt, /*else*/ true);
302 types.pop_back();
303 builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
304 Value crd = genLoad(
305 builder, loc, desc.getMemRefField(crdFidx),
306 crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
307 : plast);
308 Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
309 genCast(builder, loc, crd, indexType),
310 lvlCoords[lvl]);
311 scf::YieldOp::create(builder, loc, eq);
312 builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
313 if (lvl > 0)
314 genStore(builder, loc, msz, positionsAtLvl, parentPos);
315 scf::YieldOp::create(builder, loc, constantI1(builder, loc, false));
316 builder.setInsertionPointAfter(ifOp1);
317 // If present construct. Note that for a non-unique dimension level, we
318 // simply set the condition to false and rely on CSE/DCE to clean up the IR.
319 //
320 // TODO: generate less temporary IR?
321 //
322 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
323 types.push_back(desc.getField(i).getType());
324 types.push_back(indexType);
325 const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
326 : constantI1(builder, loc, false);
327 scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p, /*else*/ true);
328 // If present (fields unaffected, update pnext to plast).
329 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
330
331 // FIXME: This does not look like a clean way, but it is probably the most
332 // efficient way.
333 desc.getFields().push_back(plast);
334 scf::YieldOp::create(builder, loc, desc.getFields());
335 desc.getFields().pop_back();
336
337 // If !present (changes fields, update pnext).
338 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
339 Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
340 genStore(builder, loc, mszp1, positionsAtLvl, pp1);
341 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
342 /*value=*/lvlCoords[lvl]);
343 // Prepare the next level "as needed".
344 if ((lvl + 1) < lvlRank)
345 allocSchemeForRank(builder, loc, desc, lvl + 1);
346
347 desc.getFields().push_back(msz);
348 scf::YieldOp::create(builder, loc, desc.getFields());
349 desc.getFields().pop_back();
350
351 // Update fields and return next pos.
352 builder.setInsertionPointAfter(ifOp2);
353 unsigned o = 0;
354 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
355 desc.setField(i, ifOp2.getResult(o++));
356 return ifOp2.getResult(o);
357}
358
359/// Generates insertion finalization code.
360static void genEndInsert(OpBuilder &builder, Location loc,
361                          SparseTensorDescriptor desc) {
362 const SparseTensorType stt(desc.getRankedTensorType());
363 const Level lvlRank = stt.getLvlRank();
364 for (Level lvl = 0; lvl < lvlRank; lvl++) {
365 const auto lt = stt.getLvlType(lvl);
366 if (isCompressedLT(lt)) {
367 // Compressed dimensions need a position cleanup for all entries
368 // that were not visited during the insertion pass.
369 //
370 // TODO: avoid cleanup and keep compressed scheme consistent at all
371 // times?
372 //
373 if (lvl > 0) {
374 Type posType = stt.getPosType();
375 Value posMemRef = desc.getPosMemRef(lvl);
376 Value hi = desc.getPosMemSize(builder, loc, lvl);
377 Value zero = constantIndex(builder, loc, 0);
378 Value one = constantIndex(builder, loc, 1);
379 // Vector of only one, but needed by createFor's prototype.
380 SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
381 scf::ForOp loop = createFor(builder, loc, hi, inits, one);
382 Value i = loop.getInductionVar();
383 Value oldv = loop.getRegionIterArg(0);
384 Value newv = genLoad(builder, loc, posMemRef, i);
385 Value posZero = constantZero(builder, loc, posType);
386 Value cond = arith::CmpIOp::create(
387 builder, loc, arith::CmpIPredicate::eq, newv, posZero);
388 scf::IfOp ifOp = scf::IfOp::create(builder, loc, TypeRange(posType),
389 cond, /*else*/ true);
390 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
391 genStore(builder, loc, oldv, posMemRef, i);
392 scf::YieldOp::create(builder, loc, oldv);
393 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
394 scf::YieldOp::create(builder, loc, newv);
395 builder.setInsertionPointAfter(ifOp);
396 scf::YieldOp::create(builder, loc, ifOp.getResult(0));
397 builder.setInsertionPointAfter(loop);
398 }
399 } else {
400 assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
401 isNOutOfMLT(lt));
402 }
403 }
404}
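// Illustrative example (not from the original source): for a 4-row CSR tensor
// where only rows 0 and 3 received insertions, the level-1 positions may read
// [0, 2, 0, 2, 5] before cleanup; the loop above overwrites each untouched
// zero entry with the previously seen value, yielding the monotone array
// [0, 2, 2, 2, 5].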
405
406/// Generates a subview into the sizes.
407static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
408 Value sz) {
409 auto memTp = llvm::cast<MemRefType>(mem.getType());
410 // For higher-dimensional memrefs, we assume that the innermost
411 // dimension is always of the right size.
412 // TODO: generate complex truncating view here too?
413 if (memTp.getRank() > 1)
414 return mem;
415 // Truncate linear memrefs to given size.
416 return memref::SubViewOp::create(
417 builder, loc,
418 MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
419 mem, ValueRange{}, ValueRange{sz}, ValueRange{},
420 ArrayRef<int64_t>{0}, // static offset
421 ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
422 ArrayRef<int64_t>{1}) // static stride
423 .getResult();
424}
425
426/// Creates the reassociation array.
427 static SmallVector<ReassociationIndices>
428getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
429 SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
430 // Create reassociation in the form:
431 // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
432 for (unsigned i = 0; i < batchLvls; i++)
433 ret[i].push_back(i);
434
435 for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
436 ret.back().push_back(i);
437
438 return ret;
439}
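// Example (illustration only): for a rank-3 source with batchLvls = 1 this
// returns {{0}, {1, 2}}, i.e. the leading batch dimension is preserved and the
// remaining dimensions are collapsed into one.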
440
441//===----------------------------------------------------------------------===//
442// Codegen rules.
443//===----------------------------------------------------------------------===//
444
445namespace {
446
447/// Helper class to help lowering sparse_tensor.insert operation.
448class SparseInsertGenerator
449 : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
450public:
451 SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
452 bool genCall)
453 : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {};
454
455 /// Generates code along an insertion path without the need for a "cursor".
456 /// This current insertion strategy comes at the expense of some testing
457 /// overhead for each insertion. The strategy will be optimized later for
458 /// common insertion patterns. The current insertion strategy also assumes
459 /// insertions occur in "a reasonable order" that enables building the
460 /// storage scheme in an appending/inserting kind of fashion (i.e. no
461 /// in-between insertions that need data movement). The implementation
462 /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
463 ///
464 /// TODO: better unord/not-unique; also generalize, optimize, specialize!
465 SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
466 OpBuilder &builder, Location loc) {
467 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
468 const Level lvlRank = stt.getLvlRank();
469 // Extract fields and coordinates from args.
470 SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
471 MutSparseTensorDescriptor desc(stt, fields);
472 const SmallVector<Value> coords =
473 llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
474 Value value = args.back();
475 Value parentPos = constantZero(builder, loc, builder.getIndexType());
476 // Generate code for every level.
477 for (Level lvl = 0; lvl < lvlRank; lvl++) {
478 const auto lt = stt.getLvlType(lvl);
479 if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
480 // Create:
481 // if (!present) {
482 // coordinates[lvl].push_back(coords[lvl])
483 // <update positions and prepare level lvl + 1>
484 // }
485 // positions[lvl] = coordinates.size() - 1
486 // <insert @ positions[lvl] at next level lvl + 1>
487 if (isLooseCompressedLT(lt)) {
488 Value two = constantIndex(builder, loc, 2);
489 parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
490 }
491 parentPos =
492 genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
493 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
494 // Create:
495 // coordinates[lvl].push_back(coords[lvl])
496 // positions[lvl] = positions[lvl-1]
497 // <insert @ positions[lvl] at next level lvl + 1>
498 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
499 lvl, /*value=*/coords[lvl]);
500 } else {
501 assert(isDenseLT(lt));
502 // Construct the new position as:
503 // positions[lvl] = size * positions[lvl-1] + coords[lvl]
504 // <insert @ positions[lvl] at next level lvl + 1>
505 Value size = desc.getLvlSize(builder, loc, lvl);
506 Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
507 parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
508 }
509 }
510 // Reached the actual value append/insert.
511 if (!stt.isDenseLvl(lvlRank - 1))
512 createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
513 std::nullopt, value);
514 else
515 genStore(builder, loc, value, desc.getValMemRef(), parentPos);
516 return fields;
517 }
518
519 std::string getMangledFuncName() {
520 // The mangled name of the function has this format:
521 // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
522 constexpr const char kInsertFuncNamePrefix[] = "_insert_";
523 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
524 SmallString<32> nameBuffer;
525 llvm::raw_svector_ostream nameOstream(nameBuffer);
526 nameOstream << kInsertFuncNamePrefix;
527 const Level lvlRank = stt.getLvlRank();
528 for (Level l = 0; l < lvlRank; l++) {
529 std::string lvlType = toMLIRString(stt.getLvlType(l));
530 // Replace/remove punctuations in level properties.
531 std::replace_if(
532 lvlType.begin(), lvlType.end(),
533 [](char c) { return c == '(' || c == ','; }, '_');
534 llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
535 nameOstream << lvlType << "_";
536 }
537 // Static dim sizes are used in the generated code while dynamic sizes are
538 // loaded from the dimSizes buffer. This is the reason for adding the shape
539 // to the function name.
540 for (const auto sz : stt.getDimShape())
541 nameOstream << sz << "_";
542 // Permutation information is also used in generating insertion.
543 if (!stt.isIdentity())
544 nameOstream << stt.getDimToLvl() << "_";
545 nameOstream << stt.getElementType() << "_";
546 nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
547 return nameOstream.str().str();
548 }
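  // Example (illustration only, assuming a 100x100 CSR tensor of f64 with
  // default 0-bit position/coordinate widths): the mangled name reads
  //   _insert_dense_compressed_100_100_f64_0_0
  // so structurally identical insertions can share one generated function.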
549
550private:
551 TensorType rtp;
552};
553
554/// Sparse tensor storage conversion rule for returns.
555class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
556public:
557 using OpConversionPattern::OpConversionPattern;
558 LogicalResult
559 matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter) const override {
561 // Create a return with the flattened value extracted from sparse tensors.
562 rewriter.replaceOpWithNewOp<func::ReturnOp>(
563 op, flattenValues(adaptor.getOperands()));
564 return success();
565 }
566};
567
568/// Sparse tensor storage conversion rule for calls.
569class SparseCallConverter : public OpConversionPattern<func::CallOp> {
570public:
571 // The default CallOp converter cannot handle 1:N type conversion.
572 using OpConversionPattern::OpConversionPattern;
573 LogicalResult
574 matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
575 ConversionPatternRewriter &rewriter) const override {
576 Location loc = op.getLoc();
577 // In case of:
578 // sparse_tensor, f, sparse_tensor = call @foo(...)
579 // ==>
580 // memref..., f, memref = call @foo(...) replace with
581 // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
582 SmallVector<Type> finalRetTy;
583 if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
584 return failure();
585
586 // (1) Generates new call with flattened return value.
587 auto newCall =
588 func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
589 flattenValues(adaptor.getOperands()));
590 // (2) Gather sparse tensor returns.
591 SmallVector<SmallVector<Value>> packedResultVals;
592 // Tracks the offset of the current return value (of the original call)
593 // relative to the new call (after sparse tensor flattening).
594 unsigned retOffset = 0;
595 // Temporary buffer to hold the flattened list of types for
596 // a sparse tensor.
597 SmallVector<Type> sparseFlat;
598 for (auto ret : op.getResults()) {
599 assert(retOffset < newCall.getNumResults());
600 auto retType = ret.getType();
601 if (failed(typeConverter->convertType(retType, sparseFlat)))
602 llvm_unreachable("Failed to convert type in sparse tensor codegen");
603
604 // Converted types cannot be empty when the type conversion succeeds.
605 assert(!sparseFlat.empty());
606 if (sparseFlat.size() > 1) {
607 auto flatSize = sparseFlat.size();
608 packedResultVals.emplace_back();
609 llvm::append_range(packedResultVals.back(),
610 newCall.getResults().slice(retOffset, flatSize));
611 retOffset += flatSize;
612 } else {
613 // If this is a 1:1 conversion, no need for casting.
614 packedResultVals.emplace_back();
615 packedResultVals.back().push_back(newCall.getResult(retOffset));
616 retOffset++;
617 }
618 sparseFlat.clear();
619 }
620
621 assert(packedResultVals.size() == op.getNumResults());
622 rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
623 return success();
624 }
625};
626
627/// Sparse codegen rule for level accesses.
628class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
629public:
630 using OpConversionPattern::OpConversionPattern;
631 LogicalResult
632 matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
633 ConversionPatternRewriter &rewriter) const override {
634 std::optional<int64_t> lvl = op.getConstantLvlIndex();
635 RankedTensorType srcType = op.getSource().getType();
636 if (!lvl || !getSparseTensorEncoding(srcType))
637 return failure();
638
639 auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
640 auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
641
642 rewriter.replaceOp(op, sz);
643 return success();
644 }
645};
646
647// TODO: use a new SortCOO operation here instead of reusing convert op.
648struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
649 using OpConversionPattern::OpConversionPattern;
650 LogicalResult
651 matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
652 ConversionPatternRewriter &rewriter) const override {
653 Location loc = op.getLoc();
654 MLIRContext *ctx = op.getContext();
655
656 SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
657 SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
658
659 // Should have been verified.
660 assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
661 dstStt.isCOOType() && srcStt.isCOOType());
662 assert(dstStt.hasSameDimToLvl(srcStt));
663
664 // We don't need a mutable descriptor here as we perform sorting in-place.
665 auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
666 op.getInputCoo().getType());
667 auto nnz = desc.getValMemSize(rewriter, op.getLoc());
668 auto crd = desc.getAOSMemRef();
669 auto val = desc.getValMemRef();
670
671 // Otherwise we need another data shuffle and a non-identity map.
672 assert(dstStt.hasSameDimToLvl(srcStt));
673 (void)dstStt; // to silence warning when assertion is disabled
674
675 auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
676
677 SortOp::create(rewriter, loc, nnz, crd, ValueRange{val}, id,
678 rewriter.getIndexAttr(0), op.getAlgorithm());
679
680 // Since we do in-place sorting, the destination tensor will have the same
681 // set of memrefs as the source tensor.
682 rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
683 return success();
684 }
685};
686
687template <typename Op, StorageSpecifierKind kind>
688class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
689public:
690 using OpConversionPattern<Op>::OpConversionPattern;
691 using typename OpConversionPattern<Op>::OneToNOpAdaptor;
692
693 LogicalResult
694 matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
695 ConversionPatternRewriter &rewriter) const override {
696 // Simply lowers to a specifier.get <field> operation.
697 auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
698 op.getSlice().getType());
699 auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
700 op.getDim().getZExtValue());
701
702 rewriter.replaceOp(op, v);
703 return success();
704 }
705};
706
707/// Sparse codegen rule for trivial tensor casts.
708class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
709public:
710 using OpConversionPattern::OpConversionPattern;
711 LogicalResult
712 matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
713 ConversionPatternRewriter &rewriter) const override {
714 // Only rewrite identically annotated source/dest.
715 auto encDst = getSparseTensorEncoding(op.getType());
716 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
717 if (!encDst || encDst != encSrc)
718 return failure();
719 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
720 return success();
721 }
722};
723
724class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
725public:
726 using OpConversionPattern::OpConversionPattern;
727 LogicalResult
728 matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
729 ConversionPatternRewriter &rewriter) const override {
730 // Simply fold the operation.
731 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
732 return success();
733 }
734};
735
736/// Sparse codegen rule for the alloc operator.
737class SparseTensorAllocConverter
738 : public OpConversionPattern<bufferization::AllocTensorOp> {
739public:
740 using OpConversionPattern::OpConversionPattern;
741 SparseTensorAllocConverter(const TypeConverter &typeConverter,
742 MLIRContext *context, bool enableInit)
743 : OpConversionPattern(typeConverter, context),
744 enableBufferInitialization(enableInit) {}
745
746 LogicalResult
747 matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
748 ConversionPatternRewriter &rewriter) const override {
749 const auto resType = getSparseTensorType(op);
750 if (!resType.hasEncoding())
751 return failure();
752
753 Location loc = op.getLoc();
754 // Deal with copy.
755 if (op.getCopy()) {
756      auto desc = getDescriptorFromTensorTuple(
757 adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
758 SmallVector<Value> fields;
759 fields.reserve(desc.getNumFields());
760 // Memcpy on memref fields.
761 for (auto field : desc.getMemRefFields()) {
762 auto memrefTp = cast<MemRefType>(field.getType());
763 auto size = memref::DimOp::create(rewriter, loc, field, 0);
764 auto copied =
765 memref::AllocOp::create(rewriter, loc, memrefTp, ValueRange{size});
766 memref::CopyOp::create(rewriter, loc, field, copied);
767 fields.push_back(copied);
768 }
769 // Reuses specifier.
770 fields.push_back(desc.getSpecifier());
771 assert(fields.size() == desc.getNumFields());
772 rewriter.replaceOpWithMultiple(op, {fields});
773 return success();
774 }
775
776 if (!resType.isIdentity()) {
777 return rewriter.notifyMatchFailure(
778 op, "try run --sparse-reinterpret-map before codegen");
779 }
780 // Level sizes equal dimension sizes since the lvl2dim map is an identity map.
781 SmallVector<Value> lvlSizesValues;
782 createDimSizes(rewriter, loc, resType,
783 flattenValues(adaptor.getDynamicSizes()),
784 /*dimSizesValues=*/lvlSizesValues);
785
786 // Construct allocation for each field.
787 Value sizeHint = op.getSizeHint();
788 SmallVector<Value> fields;
789 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
790 sizeHint, lvlSizesValues, fields);
791
792 // Replace operation with resulting memrefs.
793 rewriter.replaceOpWithMultiple(op, {fields});
794 return success();
795 }
796
797private:
798 bool enableBufferInitialization;
799};
800
801/// Sparse codegen rule for the empty tensor operator.
802class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
803public:
804 using OpConversionPattern::OpConversionPattern;
805 SparseTensorEmptyConverter(const TypeConverter &typeConverter,
806 MLIRContext *context, bool enableInit)
807 : OpConversionPattern(typeConverter, context),
808 enableBufferInitialization(enableInit) {}
809
810 LogicalResult
811 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
812 ConversionPatternRewriter &rewriter) const override {
813 const auto resType = getSparseTensorType(op);
814 if (!resType.hasEncoding())
815 return failure();
816
817 if (!resType.isIdentity()) {
818 return rewriter.notifyMatchFailure(
819 op, "try run --sparse-reinterpret-map before codegen");
820 }
821
822 Location loc = op.getLoc();
823 // Level sizes equal dimension sizes since the lvl2dim map is an identity map.
824 SmallVector<Value> lvlSizesValues;
825 createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
826 /*dimSizesValues=*/lvlSizesValues);
827 // Construct allocation for each field.
828 Value sizeHint; // none
829 SmallVector<Value> fields;
830 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
831 sizeHint, lvlSizesValues, fields);
832
833 // Replace operation with resulting memrefs.
834 rewriter.replaceOpWithMultiple(op, {fields});
835 return success();
836 }
837
838private:
839 bool enableBufferInitialization;
840};
841
842/// Sparse codegen rule for the dealloc operator.
843class SparseTensorDeallocConverter
844 : public OpConversionPattern<bufferization::DeallocTensorOp> {
845public:
846 using OpConversionPattern::OpConversionPattern;
847 SparseTensorDeallocConverter(const TypeConverter &typeConverter,
848 MLIRContext *context, bool createDeallocs)
849 : OpConversionPattern(typeConverter, context),
850 createDeallocs(createDeallocs) {}
851
852 LogicalResult
853 matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
854 ConversionPatternRewriter &rewriter) const override {
855 auto enc = getSparseTensorEncoding(op.getTensor().getType());
856 if (!enc)
857 return failure();
858
859 // If the user requests not to deallocate sparse tensors, simply erase the
860 // operation.
861 if (createDeallocs) {
862 // Replace the sparse tensor deallocation with field deallocations.
863 Location loc = op.getLoc();
864      auto desc = getDescriptorFromTensorTuple(
865 adaptor.getTensor(),
866 cast<RankedTensorType>(op.getTensor().getType()));
867 for (auto input : desc.getMemRefFields())
868 // Deallocate every buffer used to store the sparse tensor handler.
869 memref::DeallocOp::create(rewriter, loc, input);
870 }
871 rewriter.eraseOp(op);
872 return success();
873 }
874
875private:
876 const bool createDeallocs;
877};
878
879/// Sparse codegen rule for tensor rematerialization.
880class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
881public:
882 using OpConversionPattern::OpConversionPattern;
883 LogicalResult
884 matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
885 ConversionPatternRewriter &rewriter) const override {
886 // Prepare descriptor.
887 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
888 op.getTensor().getType());
889 // Generate optional insertion finalization code.
890 if (op.getHasInserts())
891 genEndInsert(rewriter, op.getLoc(), desc);
892 // Replace operation with resulting memrefs.
893 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
894 return success();
895 }
896};
897
898/// Sparse codegen rule for the expand op.
899class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
900public:
901 using OpConversionPattern::OpConversionPattern;
902 LogicalResult
903 matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
904 ConversionPatternRewriter &rewriter) const override {
905 if (!getSparseTensorEncoding(op.getTensor().getType()))
906 return failure();
907 Location loc = op->getLoc();
908 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
909 op.getTensor().getType());
910 const auto srcType = getSparseTensorType(op.getTensor());
911 Type eltType = srcType.getElementType();
912 Type boolType = rewriter.getIntegerType(1);
913 Type idxType = rewriter.getIndexType();
914 // All initialization should be done on entry of the loop nest.
915 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
916
917 // Determine the size for access expansion (always the innermost stored
918 // level size).
919 const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
920 // Generate a memref for `sz` elements of type `t`.
921 const auto genAlloc = [&](Type t) {
922 const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
923 return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
924 };
925 // Allocate temporary buffers for values/filled-switch and added.
926 // We do not use stack buffers for this, since the expanded size may
927 // be rather large (as it envelops a single expanded dense dimension).
928 Value values = genAlloc(eltType);
929 Value filled = genAlloc(boolType);
930 Value added = genAlloc(idxType);
931 Value zero = constantZero(rewriter, loc, idxType);
932 // Reset the values/filled-switch to all-zero/false. Note that this
933 // introduces an O(N) operation into the computation, but this reset
934 // operation is amortized over the innermost loops for the access
935 // pattern expansion. As noted in the operation doc, we would like
936 // to amortize this setup cost even between kernels.
937 linalg::FillOp::create(rewriter, loc,
938 ValueRange{constantZero(rewriter, loc, eltType)},
939 ValueRange{values});
940 linalg::FillOp::create(rewriter, loc,
941 ValueRange{constantZero(rewriter, loc, boolType)},
942 ValueRange{filled});
943 // Replace expansion op with these buffers and initial coordinate.
944 assert(op.getNumResults() == 4);
945 rewriter.replaceOp(op, {values, filled, added, zero});
946 return success();
947 }
948};
949
950/// Sparse codegen rule for the compress operator.
951class SparseCompressConverter : public OpConversionPattern<CompressOp> {
952public:
953 using OpConversionPattern::OpConversionPattern;
954 LogicalResult
955 matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
956 ConversionPatternRewriter &rewriter) const override {
957 Location loc = op->getLoc();
958 SmallVector<Value> fields;
959 auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
960 op.getTensor().getType());
961 Value values = llvm::getSingleElement(adaptor.getValues());
962 Value filled = llvm::getSingleElement(adaptor.getFilled());
963 Value added = llvm::getSingleElement(adaptor.getAdded());
964 Value count = llvm::getSingleElement(adaptor.getCount());
965 const SparseTensorType dstType(desc.getRankedTensorType());
966 Type eltType = dstType.getElementType();
967
968 // If the innermost level is ordered, we need to sort the coordinates
969 // in the "added" array prior to applying the compression.
970 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
971 SortOp::create(rewriter, loc, count, added, ValueRange{},
972 rewriter.getMultiDimIdentityMap(1),
973 rewriter.getIndexAttr(0),
974 SparseTensorSortKind::HybridQuickSort);
975 // While performing the insertions, we also need to reset the elements
976 // of the values/filled-switch by only iterating over the set elements,
977 // to ensure that the runtime complexity remains proportional to the
978 // sparsity of the expanded access pattern.
979 //
980 // Generate
981 // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
982 // crd = added[i];
983 // value = values[crd];
984 // insert({lvlCoords, crd}, value);
985 // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
986 // values[crd] = 0;
987 // filled[crd] = false;
988 // yield new_memrefs
989 // }
990 scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
991 Value i = loop.getInductionVar();
992
993 Value crd = genLoad(rewriter, loc, added, i);
994 Value value = genLoad(rewriter, loc, values, crd);
995 SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
996 SmallVector<Type> flatSpTensorTps = llvm::map_to_vector(
997 desc.getFields(), [](Value v) { return v.getType(); });
998 SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
999 params.append(flatLvlCoords.begin(), flatLvlCoords.end());
1000 params.push_back(crd);
1001 params.push_back(value);
1002 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1003 params, /*genCall=*/true);
1004 SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1005 genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
1006 genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1007 scf::YieldOp::create(rewriter, loc, insertRet);
1008
1009 rewriter.setInsertionPointAfter(loop);
1010 // Deallocate the buffers on exit of the full loop nest.
1011 Operation *parent = getTop(op);
1012 rewriter.setInsertionPointAfter(parent);
1013 memref::DeallocOp::create(rewriter, loc, values);
1014 memref::DeallocOp::create(rewriter, loc, filled);
1015 memref::DeallocOp::create(rewriter, loc, added);
1016 // Replace operation with resulting memrefs.
1017 rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1018 return success();
1019 }
1020};
1021
1022/// Sparse codegen rule for the insert operator.
1023class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1024public:
1025 using OpConversionPattern::OpConversionPattern;
1026 LogicalResult
1027 matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1028 ConversionPatternRewriter &rewriter) const override {
1029 auto stt = getSparseTensorType(op.getDest());
1030 if (!stt.hasEncoding())
1031 return failure();
1032 assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1033
1034 Location loc = op.getLoc();
1035 auto desc =
1036 getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1037 TypeRange flatSpTensorTps = desc.getFields().getTypes();
1038 SmallVector<Value> params = llvm::to_vector(desc.getFields());
1039 SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
1040 params.append(flatIndices.begin(), flatIndices.end());
1041 params.push_back(llvm::getSingleElement(adaptor.getScalar()));
1042 SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1043 params, /*genCall=*/true);
1044 SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1045 // Replace operation with resulting memrefs.
1046 rewriter.replaceOpWithMultiple(op, {ret});
1047 return success();
1048 }
1049};
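// Note (added for illustration): with genCall = true, a tensor.insert into a
// CSR tensor lowers to roughly
//   %fields:N = call @_insert_dense_compressed_..._f64_0_0(%fields..., %i, %j, %v)
// where the callee is generated on first use (see getMangledFuncName above).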
1050
1051/// Sparse codegen rule for position accesses.
1052class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1053public:
1054 using OpAdaptor = ToPositionsOp::Adaptor;
1055 using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1056 LogicalResult
1057 matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1058 ConversionPatternRewriter &rewriter) const override {
1059 // Replace the requested position access with corresponding field.
1060 // The view is restricted to the actual size to ensure clients
1061 // of this operation truly observe size, not capacity!
1062 Location loc = op.getLoc();
1063 Level lvl = op.getLevel();
1064 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1065 op.getTensor().getType());
1066 auto mem = desc.getPosMemRef(lvl);
1067 auto size = desc.getPosMemSize(rewriter, loc, lvl);
1068 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1069 return success();
1070 }
1071};
1072
1073/// Sparse codegen rule for accessing the coordinates arrays.
1074class SparseToCoordinatesConverter
1075 : public OpConversionPattern<ToCoordinatesOp> {
1076public:
1077 using OpAdaptor = ToCoordinatesOp::Adaptor;
1078 using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1079 LogicalResult
1080 matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1081 ConversionPatternRewriter &rewriter) const override {
1082 // Replace the requested coordinates access with corresponding field.
1083 // The view is restricted to the actual size to ensure clients
1084 // of this operation truly observe size, not capacity!
1085 Location loc = op.getLoc();
1086 Level lvl = op.getLevel();
1087 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1088 op.getTensor().getType());
1089 auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1090 if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1091 auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1092 mem = genSliceToSize(rewriter, loc, mem, size);
1093 }
1094 rewriter.replaceOp(op, mem);
1095 return success();
1096 }
1097};
1098
1099/// Sparse codegen rule for accessing the linear coordinates buffer.
1100class SparseToCoordinatesBufferConverter
1101 : public OpConversionPattern<ToCoordinatesBufferOp> {
1102public:
1103 using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
1104 using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1105 LogicalResult
1106 matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1107 ConversionPatternRewriter &rewriter) const override {
1108 // Replace the requested coordinates access with corresponding field.
1109 // The view is restricted to the actual size to ensure clients
1110 // of this operation truly observe size, not capacity!
1111 Location loc = op.getLoc();
1112 Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1113 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1114 op.getTensor().getType());
1115 auto mem = desc.getAOSMemRef();
1116 auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1117 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1118 return success();
1119 }
1120};
1121
1122/// Sparse codegen rule for value accesses.
1123class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1124public:
1125 using OpAdaptor = ToValuesOp::Adaptor;
1126 using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1127 LogicalResult
1128 matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1129 ConversionPatternRewriter &rewriter) const override {
1130 // Replace the requested values access with corresponding field.
1131 // The view is restricted to the actual size to ensure clients
1132 // of this operation truly observe size, not capacity!
1133 Location loc = op.getLoc();
1134 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1135 op.getTensor().getType());
1136 auto mem = desc.getValMemRef();
1137 auto size = desc.getValMemSize(rewriter, loc);
1138 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1139 return success();
1140 }
1141};
1142
1143/// Sparse codegen rule for the convert operator.
1144class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1145public:
1146 using OpConversionPattern::OpConversionPattern;
1147 LogicalResult
1148 matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1149 ConversionPatternRewriter &rewriter) const override {
1150 SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1151 SparseTensorEncodingAttr encSrc =
1152 getSparseTensorEncoding(op.getSource().getType());
1153 // The output tensor cannot be a slice; those cases should have been
1154 // rejected by ConvertOp::verify() already.
1155 assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
1156 // Different encoding (except for different bitwidth) should be handled by
1157 // rewriting.
1158 // We need further rewrites if the input tensor is a slice too.
1159 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1160 encSrc.isSlice()) {
1161 return failure();
1162 }
1163
1164 Type retElemTp = op.getResult().getType().getElementType();
1165 Type srcElemTp = op.getSource().getType().getElementType();
1166 // Fold the trivial cases.
1167 if (retElemTp == srcElemTp && encDst == encSrc) {
1168 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1169 return success();
1170 }
1171 //
1172 // Do element-wise type conversion without using InsertOp.
1173 //
1174 // for each memref in srcTensor:
1175 // dst = memref.alloc
1176 // if srcMemRefType != dstMemRefType:
1177 // for every dst[i] = cast(src[i])
1178 // else:
1179 // dst = memref.copy(src)
1180 Location loc = op.getLoc();
1181 auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1182 op.getSource().getType());
1183 SmallVector<Value> fields;
1184    foreachFieldAndTypeInSparseTensor(
1185 SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1186 [&rewriter, &fields, srcDesc,
1187 loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1188 LevelType /*lt*/) -> bool {
1189 // Simply reuses the storage specifier as it is an SSA value.
1190 if (fKind == SparseTensorFieldKind::StorageSpec) {
1191 fields.push_back(srcDesc.getSpecifier());
1192 } else {
1193 // Allocates new memrefs
1194 Value srcMem = srcDesc.getMemRefField(fIdx);
1195 // TODO: We can instead use the actual memSize in specifier, that
1196 // would require a subViewOp to avoid overflow when copying
1197 // values.
1198 Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1199 auto dstMem = memref::AllocOp::create(rewriter, loc,
1200 cast<MemRefType>(fTp), sz);
1201 if (fTp != srcMem.getType()) {
1202 // Converts elements type.
1203 scf::buildLoopNest(
1204 rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1205 constantIndex(rewriter, loc, 1),
1206 [srcMem, &dstMem](OpBuilder &builder, Location loc,
1207 ValueRange ivs) {
1208 Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
1209 Value casted = genCast(builder, loc, v,
1210 dstMem.getType().getElementType());
1211 memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
1212 });
1213 } else {
1214 // TODO: We can even reuse the same memref for the new tensor,
1215 // but that requires a `ref-counting` based memory management
1216 // for shared memrefs between multiple sparse tensors.
1217 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1218 }
1219 fields.push_back(dstMem);
1220 }
1221 return true;
1222 });
1223
1224 rewriter.replaceOpWithMultiple(op, {fields});
1225 return success();
1226 }
1227};
1228
1229class SparseExtractSliceConverter
1230 : public OpConversionPattern<tensor::ExtractSliceOp> {
1231public:
1232 using OpConversionPattern::OpConversionPattern;
1233 LogicalResult
1234 matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1235 ConversionPatternRewriter &rewriter) const override {
1236 Location loc = op.getLoc();
1237 MLIRContext *ctx = op.getContext();
1238 auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1239 auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1240 // TODO: We should check these in ExtractSliceOp::verify.
1241 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1242 return failure();
1243 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1244
1245 SmallVector<Value> fields;
1246 auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1247 op.getSource().getType());
1248
1249 auto newSpec = StorageSpecifierInitOp::create(
1250 rewriter, loc, StorageSpecifierType::get(ctx, dstEnc),
1251 desc.getSpecifier());
1252 desc.setSpecifier(newSpec);
1253
1254 // Fills in slice information.
1255 for (auto [idx, offset, size, stride] : llvm::enumerate(
1256 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1257 Dimension dim = idx;
1258
1259 Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1260 Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1261 Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1262 // TODO: We could probably only set the dynamic value here, but it would
1263 // require us to fill the hole when casting a static slice to a dynamic
1264 // slice.
1265 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1266 dim, offsetV);
1267
1268 // FIXME: we need to distinguish level sizes and dimension sizes for slices
1269 // here. Maybe we should store slice level sizes in a different array
1270 // instead of reusing it.
1271 assert(srcEnc.isIdentity());
1272 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1273 sizeV);
1274 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1275 dim, strideV);
1276 }
1277
1278 // NOTE: we cannot generate tuples directly from the descriptor here, as
1279 // the descriptor holds the original type, yet we want the slice type
1280 // here (they share every memref but carry an updated specifier).
1281 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1282 return success();
1283 }
1284};
1285
1286/// Sparse codegen rule for number of entries operator.
1287class SparseNumberOfEntriesConverter
1288 : public OpConversionPattern<NumberOfEntriesOp> {
1289public:
1290 using OpConversionPattern::OpConversionPattern;
1291 LogicalResult
1292 matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1293 ConversionPatternRewriter &rewriter) const override {
1294 // Query memSizes for the actually stored values.
1295 // FIXME: the nse value computed in this way might be wrong when there is
1296 // any "loose_compressed" level.
1297 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1298 op.getTensor().getType());
1299 rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1300 return success();
1301 }
1302};
1303
1304struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1305 using OpConversionPattern::OpConversionPattern;
1306 LogicalResult
1307 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1308 ConversionPatternRewriter &rewriter) const override {
1309 Location loc = op.getLoc();
1310 const auto stt = getSparseTensorType(op.getResult());
1311
1312 SmallVector<Value> fields;
1313
1314    foreachFieldAndTypeInSparseTensor(
1315 stt,
1316 [&rewriter, &fields, &op, &stt,
1317 loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1318 Level /*lvl*/, LevelType lt) -> bool {
1319 assert(fields.size() == fIdx);
1320 if (fKind == SparseTensorFieldKind::StorageSpec) {
1321 fields.push_back(
1322 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1323 } else {
1324 // Else simply takes the inputs.
1325 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1326 ? op.getValues()
1327 : op.getLevels()[fIdx];
1328 // TODO: handle batch.
1329 TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1330 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1331 // Flattens the buffer to batchLvlRank.
1332 auto reassoc = getReassociationForFlattening(
1333 mem.getType(), stt.getBatchLvlRank());
1334 mem = memref::CastOp::create(
1335 rewriter, loc, fType,
1336 memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
1337 } else {
1338 mem = memref::CastOp::create(rewriter, loc, fType, mem);
1339 }
1340 fields.push_back(mem);
1341 }
1342 return true;
1343 });
1344
1345 MutSparseTensorDescriptor desc(stt, fields);
1346 Value c0 = constantIndex(rewriter, loc, 0);
1347 Value c1 = constantIndex(rewriter, loc, 1);
1348 Value c2 = constantIndex(rewriter, loc, 2);
1349 Value posBack = c0; // index to the last value in the position array
1350 Value memSize = c1; // memory size for current array
1351
1352 Level trailCOOStart = stt.getAoSCOOStart();
1353 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1354 // Sets up SparseTensorSpecifier.
1355 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1356 assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
1357
1358 // Sets up the level size.
1359 auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1360 desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1361 // We use a single AOS array to store the trailing COO, so there is only
1362 // one memory size to set for the entire COO section.
1363 if (lvl > trailCOOStart)
1364 continue;
1365
1366 // Sets up the memory size by reading the last value in position array.
1367 LevelType lt = stt.getLvlType(lvl);
1368 // Simply forwards the position index when this is a dense level.
1369 if (lt.isa<LevelFormat::Dense>()) {
1370 memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
1371 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1372 continue;
1373 }
1374 if (lt.isa<LevelFormat::Batch>()) {
1375 // Skips batch levels as they are not linearized.
1376 // FIXME: this assumes that every batch has the same number of nse; needs
1377 // to be generalized to handle varied-size batches.
1378 continue;
1379 }
1380
1381 if (isWithPosLT(lt)) {
1382 assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1383 if (isLooseCompressedLT(lt)) {
1384 memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
1385 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1386 } else {
1387 assert(isCompressedLT(lt));
1388 posBack = memSize;
1389 memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
1390 }
1391 desc.setPosMemSize(rewriter, loc, lvl, memSize);
1392        // The last value in the position array is the memory size for the
1393        // next level. FIXME: this assumes that every batch has the same
1394        // number of nse; needs to be generalized for varied-size batches.
1395 SmallVector<Value> batched(stt.getBatchLvlRank(),
1396 constantIndex(rewriter, loc, 0));
1397 batched.push_back(posBack);
1398 memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1399 posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
1400 }
1401 assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1402      // FIXME: This seems unnecessarily complex; can we simplify it?
1403 if (lvl == trailCOOStart) {
1404 Value cooSz = arith::MulIOp::create(
1405 rewriter, loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1406 desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1407 } else {
1408 desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1409 }
1410 }
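    // Worked example (CSR with M statically-known rows, levels dense then
    // compressed): level 0 sets memSize = M and posBack = M-1; level 1 sets
    // the position memory size to M+1, re-reads memSize from positions[M]
    // (the number of stored entries), and records that count as the
    // coordinate memory size; the same count becomes the value memory size
    // set below.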
1411 desc.setValMemSize(rewriter, loc, memSize);
1412
1413 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1414 return success();
1415 }
1416};
1417
1418struct SparseDisassembleOpConverter
1419 : public OpConversionPattern<DisassembleOp> {
1420 using OpConversionPattern::OpConversionPattern;
1421 SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1422 MLIRContext *context)
1423 : OpConversionPattern(typeConverter, context) {}
1424
1425 LogicalResult
1426 matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1427 ConversionPatternRewriter &rewriter) const override {
1428 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1429 op.getTensor().getType());
1430 Location loc = op.getLoc();
1431 SmallVector<Value> retMem;
1432 SmallVector<Value> retLen;
1433 desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1434 &retLen](FieldIndex fid,
1435                                            SparseTensorFieldKind fKind,
1436                                            Level lvl, LevelType lt) -> bool {
1437 if (fKind == SparseTensorFieldKind::StorageSpec)
1438 return true;
1439 SparseTensorType stt(desc.getRankedTensorType());
1440      Value sz, src;
1441      TypedValue<BaseMemRefType> dst;
1442 if (fKind == SparseTensorFieldKind::ValMemRef) {
1443 sz = desc.getValMemSize(rewriter, loc);
1444 src = desc.getValMemRef();
1445 dst = genToMemref(rewriter, loc, op.getOutValues());
1446
1447 retMem.push_back(dst);
1448 Type valLenTp = op.getValLen().getType();
1449 retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1450 } else {
1451 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1452 fKind == SparseTensorFieldKind::CrdMemRef);
1453
1454 sz = fKind == SparseTensorFieldKind::PosMemRef
1455 ? desc.getPosMemSize(rewriter, loc, lvl)
1456 : desc.getCrdMemSize(rewriter, loc, lvl);
1457 src = desc.getMemRefField(fid);
1458 dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1459 retMem.push_back(dst);
1460 // Retrieves the corresponding level length type.
1461 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1462 retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1463 }
1464 Value flatOut = dst;
1465 if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1466 auto reassoc =
1467 getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1468 flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
1469 }
1470 Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1471 Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1472 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1473 return true;
1474 });
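    // `retMem` now holds the user-provided output buffers (one per level
    // buffer plus the values buffer) into which the used prefix of each
    // internal memref was copied, and `retLen` holds the corresponding
    // lengths.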
1475
1476 // Converts MemRefs back to Tensors.
1477 SmallVector<Value> retValues =
1478 llvm::map_to_vector(retMem, [&rewriter, loc](Value v) -> Value {
1479 return bufferization::ToTensorOp::create(
1480 rewriter, loc, memref::getTensorTypeFromMemRefType(v.getType()),
1481 v);
1482 });
1483 // Appends the actual memory length used in each buffer returned.
1484 retValues.append(retLen.begin(), retLen.end());
1485 rewriter.replaceOp(op, retValues);
1486 return success();
1487 }
1488};
1489
1490struct SparseNewConverter : public OpConversionPattern<NewOp> {
1491 using OpConversionPattern::OpConversionPattern;
1492 LogicalResult
1493 matchAndRewrite(NewOp op, OpAdaptor adaptor,
1494 ConversionPatternRewriter &rewriter) const override {
1495 Location loc = op.getLoc();
1496 const auto dstTp = getSparseTensorType(op.getResult());
1497 // Creating COO with NewOp is handled by direct IR codegen. All other cases
1498 // are handled by rewriting.
1499 if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1500 return failure();
1501
1502 // Implement as follows:
1503 // %reader = @createCheckedSparseTensorReader(%filename)
1504 // %nse = @getSparseTensorNSE(%reader)
1505 // %coo = bufferization.alloc_tensor an ordered COO with
1506 // dst dim ordering, size_hint = %nse
1507 // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1508 // %values = sparse_tensor.values(%coo)
1509 // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1510 // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1511 // update storage specifier
1512 // @delSparseTensorReader(%reader)
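    // For example, a lowering target of this pattern would be
    //   %coo = sparse_tensor.new %src : !llvm.ptr to tensor<?x?xf64, #SortedCOO>
    // where #SortedCOO stands for an illustrative all-ordered AoS COO
    // encoding (the case selected by the getAoSCOOStart() == 0 check above).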
1513 SmallVector<Value> dimSizesValues;
1514 Value dimSizesBuffer;
1515 Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1516 dimSizesValues, dimSizesBuffer);
1517
1518 // Get the number of stored entries.
1519 const Type indexTp = rewriter.getIndexType();
1520 Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1521 {indexTp}, {reader}, EmitCInterface::Off)
1522 .getResult(0);
1523
1524 // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1525 SmallVector<Value> lvlSizesValues;
1526 Value dim2lvlBuffer;
1527 Value lvl2dimBuffer;
1528 genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1529 lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1530
1531 // Construct allocation for each field.
1532 Value sizeHint = nse;
1533 SmallVector<Value> fields;
1534 createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1535 lvlSizesValues, fields);
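    // Passing nse as the size hint makes createAllocFields size the
    // coordinate and value buffers to hold all stored entries, which the
    // bulk read below fills directly.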
1536
1537 // Read the COO tensor data.
1538 MutSparseTensorDescriptor desc(dstTp, fields);
1539 Value xs = desc.getAOSMemRef();
1540 Value ys = desc.getValMemRef();
1541 const Type boolTp = rewriter.getIntegerType(1);
1542 const Type elemTp = dstTp.getElementType();
1543 const Type crdTp = dstTp.getCrdType();
1544    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1545                                          overheadTypeFunctionSuffix(crdTp),
1546                                          primaryTypeFunctionSuffix(elemTp)};
1547 Value isSorted =
1548 createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1549 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1550 EmitCInterface::On)
1551 .getResult(0);
1552
1553 // If the destination tensor is a sorted COO, we need to sort the COO tensor
1554 // data if the input elements aren't sorted yet.
1555 const Level lvlRank = dstTp.getLvlRank();
1556 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1557 Value kFalse = constantI1(rewriter, loc, false);
1558 Value notSorted = arith::CmpIOp::create(
1559 rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1560 scf::IfOp ifOp =
1561 scf::IfOp::create(rewriter, loc, notSorted, /*else*/ false);
1562 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1563 auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1564 SortOp::create(rewriter, loc, nse, xs, ValueRange{ys}, xPerm,
1565 rewriter.getIndexAttr(0),
1566 SparseTensorSortKind::HybridQuickSort);
1567 rewriter.setInsertionPointAfter(ifOp);
1568 }
1569
1570 // Set PosMemRef0[1] = nse.
1571 const Value c1 = constantIndex(rewriter, loc, 1);
1572 const Value posMemref0 = desc.getPosMemRef(0);
1573 const Type posTp = dstTp.getPosType();
1574 const Value posNse = genCast(rewriter, loc, nse, posTp);
1575 memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
1576
1577 // Update storage specifier.
1578 Value coordinatesSize = arith::MulIOp::create(
1579 rewriter, loc, nse, constantIndex(rewriter, loc, lvlRank));
1580 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1581 coordinatesSize);
1582 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1583 std::nullopt, nse);
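    // (The AoS coordinate buffer interleaves all lvlRank coordinates of each
    // stored entry, hence crd memory size = nse * lvlRank, while the value
    // buffer holds exactly nse entries.)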
1584
1585 // Release the sparse tensor reader.
1586 createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1587 EmitCInterface::Off);
1588
1589 // Replace operation with resulting memrefs.
1590 rewriter.replaceOpWithMultiple(op, {fields});
1591 return success();
1592 }
1593};
1594
1595struct SparseHasRuntimeLibraryConverter
1596 : public OpConversionPattern<HasRuntimeLibraryOp> {
1597 using OpConversionPattern::OpConversionPattern;
1598 LogicalResult
1599 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1600 ConversionPatternRewriter &rewriter) const override {
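    // This is the library-free codegen path, so sparse_tensor.has_runtime_library
    // simply lowers to the constant false.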
1601 auto i1Type = rewriter.getI1Type();
1602 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1603 op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1604 return success();
1605 }
1606};
1607
1608} // namespace
1609
1610//===----------------------------------------------------------------------===//
1611// Public method for populating conversion rules.
1612//===----------------------------------------------------------------------===//
1613
1614/// Populates the given patterns list with the conversion rules required to
1615/// lower sparse tensor types and primitives to explicit buffer-based code.
1616void mlir::populateSparseTensorCodegenPatterns(
1617    const TypeConverter &typeConverter, RewritePatternSet &patterns,
1618 bool createSparseDeallocs, bool enableBufferInitialization) {
1619 patterns.add<
1620 SparseAssembleOpConverter, SparseDisassembleOpConverter,
1621 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1622 SparseCastConverter, SparseExtractSliceConverter,
1623 SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1624 SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1625 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1626 StorageSpecifierKind::DimOffset>,
1627 SparseSliceGetterOpConverter<ToSliceStrideOp,
1628 StorageSpecifierKind::DimStride>,
1629 SparseToPositionsConverter, SparseToCoordinatesConverter,
1630 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1631 SparseConvertConverter, SparseNewConverter,
1632 SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1633 typeConverter, patterns.getContext());
1634 patterns.add<SparseTensorDeallocConverter>(
1635 typeConverter, patterns.getContext(), createSparseDeallocs);
1636 patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1637 typeConverter, patterns.getContext(), enableBufferInitialization);
1638}
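
The actual driver for these patterns lives in the sparse tensor pass definitions; the following is only a minimal sketch of how a pass could wire them into a dialect conversion. The pass name is hypothetical, the legality setup is deliberately simplified, and the sketch assumes the usual mlir/Pass/Pass.h and mlir/Transforms/DialectConversion.h headers plus the SparseTensorTypeToBufferConverter type converter that normally accompanies these rules.

// Sketch only: not the in-tree pass.
struct SparseCodegenSketchPass
    : public PassWrapper<SparseCodegenSketchPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseCodegenSketchPass)

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    // Maps sparse tensor types to their explicit buffer representation.
    SparseTensorTypeToBufferConverter converter;
    RewritePatternSet patterns(ctx);
    ConversionTarget target(*ctx);
    // Any op still operating on a sparse tensor type remains illegal.
    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      return converter.isLegal(op->getOperandTypes()) &&
             converter.isLegal(op->getResultTypes());
    });
    populateSparseTensorCodegenPatterns(converter, patterns,
                                        /*createSparseDeallocs=*/false,
                                        /*enableBufferInitialization=*/false);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

The in-tree driver configures the ConversionTarget in more detail (for example, for function signatures and the surrounding bufferization), so the above should be read purely as an orientation sketch.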