SparseTensorCodegen.cpp
//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor types and primitives to actual compiler
// visible buffers and actual compiler IR that implements these primitives on
// the selected sparse tensor storage schemes. This pass provides an alternative
// to the SparseTensorConversion pass, eliminating the dependence on a runtime
// support library (other than for file I/O), and providing many more
// opportunities for subsequent compiler optimization of the generated code.
//
//===----------------------------------------------------------------------===//
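
// Example (illustrative only): a 2-D CSR matrix of type
//   tensor<?x?xf64, #CSR>
// is lowered to a flat "tuple" of buffers, i.e., a positions memref and a
// coordinates memref for the compressed level, a values memref, and a
// storage specifier that tracks the actual sizes of those arrays:
//   memref<?xindex>, memref<?xindex>, memref<?xf64>,
//   !sparse_tensor.storage_specifier<#CSR>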
#include "Utils/CodegenUtils.h"
#include "Utils/SparseTensorDescriptor.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVectorExtras.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Flattens the given value ranges into a single vector of values.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  return memref::LoadOp::create(builder, loc, mem, idx);
}

/// Generates a store with proper `index` typing and proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  memref::StoreOp::create(builder, loc, val, mem, idx);
}

/// Creates a straightforward counting for-loop.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  Type indexType = builder.getIndexType();
  if (!lower)
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
  scf::ForOp forOp =
      scf::ForOp::create(builder, loc, lower, upper, one, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  builder.setInsertionPointToStart(forOp.getBody());
  return forOp;
}

/// Creates a push back operation.
static void createPushback(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor &desc,
                           SparseTensorFieldKind kind, std::optional<Level> lvl,
                           Value value, Value repeat = Value()) {
  Type etp = desc.getMemRefElementType(kind, lvl);
  Value field = desc.getMemRefField(kind, lvl);
  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);

  auto pushBackOp = PushBackOp::create(
      builder, loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl),
      field, genCast(builder, loc, value, etp), repeat);

  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                         pushBackOp.getNewSize());
}

/// Generates code that allocates a sparse storage scheme for the given rank.
static void allocSchemeForRank(OpBuilder &builder, Location loc,
                               MutSparseTensorDescriptor desc, Level startLvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  Value linear = constantIndex(builder, loc, 1);
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
      // Append linear x positions, initialized to zero. Since each compressed
      // dimension initially already has a single zero entry, this maintains
      // the desired "linear + 1" length property at all times. For loose
      // compression, we multiply linear by two in order to append both the
      // lo/hi positions.
      Value posZero = constantZero(builder, loc, stt.getPosType());
      if (isLooseCompressedLT(lt)) {
        Value two = constantIndex(builder, loc, 2);
        linear = arith::MulIOp::create(builder, loc, linear, two);
      }
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero, /*repeat=*/linear);
      return;
    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
      return; // nothing to do
    }
    // Keep compounding the size, but nothing needs to be initialized
    // at this level. We will eventually reach a compressed level or
    // otherwise the values array for the from-here "all-dense" case.
    assert(isDenseLT(lt));
    Value size = desc.getLvlSize(builder, loc, lvl);
    linear = arith::MulIOp::create(builder, loc, linear, size);
  }
  // Reached the values array, so prepare for an insertion.
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
}

/// Creates an allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    linalg::FillOp::create(builder, loc, fillValue, buffer);
  }
  return buffer;
}

/// Creates the dim sizes array, filling in from dynamic sizes.
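/// E.g. (illustrative): for tensor<10x?x20xf64> with dynSizes = {%d}, this
/// produces {c10, %d, c20}, consuming the dynamic sizes left-to-right.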
static void createDimSizes(OpBuilder &builder, Location loc,
                           SparseTensorType stt, ValueRange dynSizes,
                           /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
  const Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  unsigned i = 0;
  for (const Size sz : stt.getDimShape())
    dimSizesValues.push_back(ShapedType::isDynamic(sz)
                                 ? dynSizes[i++]
                                 : constantIndex(builder, loc, sz));
}

/// Creates an allocation for each field in the sparse tensor type. Note that
/// for all dynamic memrefs in the sparse tensor storage layout, the memory
/// size is really the capacity of the "vector", while the actual size resides
/// in the sizes array.
static void createAllocFields(OpBuilder &builder, Location loc,
                              SparseTensorType stt, bool enableInit,
                              Value sizeHint,
                              SmallVectorImpl<Value> &lvlSizesValues,
                              /*out*/ SmallVectorImpl<Value> &fields) {
  Level lvlRank = stt.getLvlRank();
  // Set up some heuristic sizes. We try to set the initial
  // size based on available information. Otherwise we just
  // initialize a few elements to start the reallocation chain.
  // TODO: refine this
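  // E.g. (illustrative): for an AoS COO tensor with a size hint of nnz, the
  // positions get capacity 2, the coordinates get capacity lvlRank * nnz,
  // and the values get capacity nnz; see the cases below.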
  Value posHeuristic, crdHeuristic, valHeuristic;
  if (stt.isAllDense()) {
    valHeuristic = lvlSizesValues[0];
    for (Level lvl = 1; lvl < lvlRank; lvl++)
      valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
                                           lvlSizesValues[lvl]);
  } else if (sizeHint) {
    if (stt.getAoSCOOStart() == 0) {
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = arith::MulIOp::create(
          builder, loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
      posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
                                           constantIndex(builder, loc, 1));
      crdHeuristic = sizeHint;
    } else {
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
    }
    valHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
  }
  // Initializes all fields. An initial storage specifier and allocated
  // positions/coordinates/values memrefs (with heuristic capacity).
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
       enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                   Level /*lvl*/, LevelType /*lt*/) -> bool {
        assert(fields.size() == fIdx);
        Value field;
        switch (fKind) {
        case SparseTensorFieldKind::StorageSpec:
          field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
          break;
        case SparseTensorFieldKind::PosMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   posHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::CrdMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   crdHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::ValMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   valHeuristic, enableInit);
          break;
        }
        assert(field);
        fields.push_back(field);
        // Returns true to continue the iteration.
        return true;
      });
  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
  // and gives all position fields an initial zero entry, so that it is
  // easier to maintain the "linear + 1" length property.
  MutSparseTensorDescriptor desc(stt, fields);
  Value posZero = constantZero(builder, loc, stt.getPosType());
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero);
  }
  allocSchemeForRank(builder, loc, desc, /*startLvl=*/0);
}

/// Helper method that generates a block specific to the compressed case:
///
///   // given: parentPos = posCursor[lvl-1]
///   pstart = desc.positions[lvl][parentPos]
///   pstop = desc.positions[lvl][parentPos+1]
///   plast = pstop - 1
///   msz = desc.coordinates[lvl].size()
///   if (pstart < pstop) {
///     isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
///   } else { // first insertion
///     isPresent = false
///     desc.positions[lvl][parentPos] = msz
///   }
///   if (isPresent) { // coordinate is already present
///     pnext = plast
///   } else {
///     desc.coordinates[lvl].push_back(lvlCoords[lvl])
///     desc.positions[lvl][parentPos+1] = msz+1
///     pnext = msz
///     <prepare level lvl+1>
///   }
///   posCursor[lvl] = pnext
static Value genCompressed(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor &desc,
                           ValueRange lvlCoords, Value /*unused*/,
                           Value parentPos, Level lvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
  SmallVector<Type> types;
  Type indexType = builder.getIndexType();
  Type boolType = builder.getIntegerType(1);
  unsigned crdFidx;
  unsigned crdStride;
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
  const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
  const Value msz =
      crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = arith::SubIOp::create(
      builder, loc, genCast(builder, loc, pstop, indexType), one);
  // Conditional expression.
  Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
                                   pstart, pstop);
  types.push_back(boolType);
  scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt, /*else*/ true);
  types.pop_back();
  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  Value crd = genLoad(
      builder, loc, desc.getMemRefField(crdFidx),
      crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
                 : plast);
  Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
                                   genCast(builder, loc, crd, indexType),
                                   lvlCoords[lvl]);
  scf::YieldOp::create(builder, loc, eq);
  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  if (lvl > 0)
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
  scf::YieldOp::create(builder, loc, constantI1(builder, loc, false));
  builder.setInsertionPointAfter(ifOp1);
  // If present construct. Note that for a non-unique dimension level, we
  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
  //
  // TODO: generate less temporary IR?
  //
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
  scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p, /*else*/ true);
  // If present (fields unaffected, update pnext to plast).
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());

  // FIXME: This does not look like a clean way to do this, but it is
  // probably the most efficient way.
  desc.getFields().push_back(plast);
  scf::YieldOp::create(builder, loc, desc.getFields());
  desc.getFields().pop_back();

  // If !present (changes fields, update pnext).
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                 /*value=*/lvlCoords[lvl]);
  // Prepare the next level "as needed".
  if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);

  desc.getFields().push_back(msz);
  scf::YieldOp::create(builder, loc, desc.getFields());
  desc.getFields().pop_back();

  // Update fields and return next pos.
  builder.setInsertionPointAfter(ifOp2);
  unsigned o = 0;
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
  return ifOp2.getResult(o);
}

/// Generates insertion finalization code.
static void genEndInsert(OpBuilder &builder, Location loc,
                         SparseTensorDescriptor desc) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt)) {
      // Compressed dimensions need a position cleanup for all entries
      // that were not visited during the insertion pass.
      //
      // TODO: avoid cleanup and keep compressed scheme consistent at all
      // times?
      //
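      // E.g. (illustrative): a positions array [0, 2, 0, 0, 5] left behind
      // by the insertion pass becomes [0, 2, 2, 2, 5], propagating the last
      // visited position into the unvisited (still zero) entries.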
      if (lvl > 0) {
        Type posType = stt.getPosType();
        Value posMemRef = desc.getPosMemRef(lvl);
        Value hi = desc.getPosMemSize(builder, loc, lvl);
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
        Value i = loop.getInductionVar();
        Value oldv = loop.getRegionIterArg(0);
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
        Value cond = arith::CmpIOp::create(
            builder, loc, arith::CmpIPredicate::eq, newv, posZero);
        scf::IfOp ifOp = scf::IfOp::create(builder, loc, TypeRange(posType),
                                           cond, /*else*/ true);
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        genStore(builder, loc, oldv, posMemRef, i);
        scf::YieldOp::create(builder, loc, oldv);
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        scf::YieldOp::create(builder, loc, newv);
        builder.setInsertionPointAfter(ifOp);
        scf::YieldOp::create(builder, loc, ifOp.getResult(0));
        builder.setInsertionPointAfter(loop);
      }
    } else {
      assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
             isNOutOfMLT(lt));
    }
  }
}

/// Generates a subview into the sizes.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                            Value sz) {
  auto memTp = llvm::cast<MemRefType>(mem.getType());
  // For higher-dimensional memrefs, we assume that the innermost
  // dimension is always of the right size.
  // TODO: generate complex truncating view here too?
  if (memTp.getRank() > 1)
    return mem;
  // Truncate linear memrefs to given size.
  return memref::SubViewOp::create(
             builder, loc,
             MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
             mem, ValueRange{}, ValueRange{sz}, ValueRange{},
             ArrayRef<int64_t>{0},                    // static offset
             ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
             ArrayRef<int64_t>{1})                    // static stride
      .getResult();
}

/// Creates the reassociation array.
static SmallVector<ReassociationIndices>
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
  // Create reassociation in the form:
  // {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}
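  // E.g. (illustrative): for a rank-4 source with batchLvls = 1, this yields
  // the reassociation {{0}, {1, 2, 3}}.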
  for (unsigned i = 0; i < batchLvls; i++)
    ret[i].push_back(i);

  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);

  return ret;
}

//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//

namespace {

/// Helper class for lowering the sparse_tensor.insert operation.
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
public:
  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
                        bool genCall)
      : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {}

  /// Generates code along an insertion path without the need for a "cursor".
  /// This current insertion strategy comes at the expense of some testing
  /// overhead for each insertion. The strategy will be optimized later for
  /// common insertion patterns. The current insertion strategy also assumes
  /// insertions occur in "a reasonable order" that enables building the
  /// storage scheme in an appending/inserting kind of fashion (i.e. no
  /// in-between insertions that need data movement). The implementation
  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
  ///
  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
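  /// E.g. (illustrative): for a CSR matrix, presenting the coordinates in
  /// lexicographic order, such as (0,0), (0,2), (1,1), lets every insertion
  /// simply append to the coordinates/values arrays and bump the positions.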
  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                       OpBuilder &builder, Location loc) {
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    const Level lvlRank = stt.getLvlRank();
    // Extract fields and coordinates from args.
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
    MutSparseTensorDescriptor desc(stt, fields);
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
    Value value = args.back();
    Value parentPos = constantZero(builder, loc, builder.getIndexType());
    // Generate code for every level.
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        // Create:
        //   if (!present) {
        //     coordinates[lvl].push_back(coords[lvl])
        //     <update positions and prepare level lvl + 1>
        //   }
        //   positions[lvl] = coordinates.size() - 1
        //   <insert @ positions[lvl] at next level lvl + 1>
        if (isLooseCompressedLT(lt)) {
          Value two = constantIndex(builder, loc, 2);
          parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
        }
        parentPos =
            genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
      } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
        // Create:
        //   coordinates[lvl].push_back(coords[lvl])
        //   positions[lvl] = positions[lvl-1]
        //   <insert @ positions[lvl] at next level lvl + 1>
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
                       lvl, /*value=*/coords[lvl]);
      } else {
        assert(isDenseLT(lt));
        // Construct the new position as:
        //   positions[lvl] = size * positions[lvl-1] + coords[lvl]
        //   <insert @ positions[lvl] at next level lvl + 1>
        Value size = desc.getLvlSize(builder, loc, lvl);
        Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
        parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
      }
    }
    // Reached the actual value append/insert.
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
    else
      genStore(builder, loc, value, desc.getValMemRef(), parentPos);
    return fields;
  }

  std::string getMangledFuncName() {
    // The mangled name of the function has this format:
    //   <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
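    // E.g. (illustrative): a static 10x10 CSR tensor with f64 elements,
    // 64-bit coordinates, and 32-bit positions mangles to
    //   _insert_dense_compressed_10_10_f64_64_32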
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    SmallString<32> nameBuffer;
    llvm::raw_svector_ostream nameOstream(nameBuffer);
    nameOstream << kInsertFuncNamePrefix;
    const Level lvlRank = stt.getLvlRank();
    for (Level l = 0; l < lvlRank; l++) {
      std::string lvlType = toMLIRString(stt.getLvlType(l));
      // Replace/remove punctuations in level properties.
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
    }
    // Static dim sizes are used in the generated code while dynamic sizes are
    // loaded from the dimSizes buffer. This is the reason for adding the shape
    // to the function name.
    for (const auto sz : stt.getDimShape())
      nameOstream << sz << "_";
    // Permutation information is also used in generating insertion.
    if (!stt.isIdentity())
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
    return nameOstream.str().str();
  }

private:
  TensorType rtp;
};

/// Sparse tensor storage conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create a return with the flattened value extracted from sparse tensors.
    rewriter.replaceOpWithNewOp<func::ReturnOp>(
        op, flattenValues(adaptor.getOperands()));
    return success();
  }
};

/// Sparse tensor storage conversion rule for calls.
class SparseCallConverter : public OpConversionPattern<func::CallOp> {
public:
  // The default CallOp converter cannot handle 1:N type conversion.
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // In case of:
    //   sparse_tensor, f, sparse_tensor = call @foo(...)
    // ==>
    //   memref..., f, memref... = call @foo(...)
    // where each flattened memref pack is cast back to a sparse tensor:
    //   cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
    SmallVector<Type> finalRetTy;
    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
      return failure();

    // (1) Generates new call with flattened return value.
    auto newCall =
        func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
                             flattenValues(adaptor.getOperands()));
    // (2) Gather sparse tensor returns.
    SmallVector<SmallVector<Value>> packedResultVals;
    // Tracks the offset of the current return value (of the original call)
    // relative to the new call (after sparse tensor flattening).
    unsigned retOffset = 0;
    // Temporary buffer to hold the flattened list of types for
    // a sparse tensor.
    SmallVector<Type> sparseFlat;
    for (auto ret : op.getResults()) {
      assert(retOffset < newCall.getNumResults());
      auto retType = ret.getType();
      if (failed(typeConverter->convertType(retType, sparseFlat)))
        llvm_unreachable("Failed to convert type in sparse tensor codegen");

      // Converted types cannot be empty when the type conversion succeeds.
      assert(!sparseFlat.empty());
      if (sparseFlat.size() > 1) {
        auto flatSize = sparseFlat.size();
        packedResultVals.emplace_back();
        llvm::append_range(packedResultVals.back(),
                           newCall.getResults().slice(retOffset, flatSize));
        retOffset += flatSize;
      } else {
        // If this is a 1:1 conversion, no need for casting.
        packedResultVals.emplace_back();
        packedResultVals.back().push_back(newCall.getResult(retOffset));
        retOffset++;
      }
      sparseFlat.clear();
    }

    assert(packedResultVals.size() == op.getNumResults());
    rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
    return success();
  }
};

/// Sparse codegen rule for level accesses.
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::optional<int64_t> lvl = op.getConstantLvlIndex();
    RankedTensorType srcType = op.getSource().getType();
    if (!lvl || !getSparseTensorEncoding(srcType))
      return failure();

    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
    auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);

    rewriter.replaceOp(op, sz);
    return success();
  }
};

// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();

    SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
    SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());

    // Should have been verified.
    assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
           dstStt.isCOOType() && srcStt.isCOOType());
    assert(dstStt.hasSameDimToLvl(srcStt));

    // We don't need a mutable descriptor here as we perform sorting in-place.
    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                             op.getInputCoo().getType());
    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
    auto crd = desc.getAOSMemRef();
    auto val = desc.getValMemRef();

    // Otherwise we need another data shuffle and a non-identity map.
    assert(dstStt.hasSameDimToLvl(srcStt));
    (void)dstStt; // to silence warning when assertion is disabled

    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);

    SortOp::create(rewriter, loc, nnz, crd, ValueRange{val}, id,
                   rewriter.getIndexAttr(0), op.getAlgorithm());

    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
    rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
    return success();
  }
};

template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;
  using typename OpConversionPattern<Op>::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());

    rewriter.replaceOp(op, v);
    return success();
  }
};

/// Sparse codegen rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

/// Sparse codegen rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    Location loc = op.getLoc();
    // Deal with copy.
    if (op.getCopy()) {
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
      SmallVector<Value> fields;
      fields.reserve(desc.getNumFields());
      // Memcpy on memref fields.
      for (auto field : desc.getMemRefFields()) {
        auto memrefTp = cast<MemRefType>(field.getType());
        auto size = memref::DimOp::create(rewriter, loc, field, 0);
        auto copied =
            memref::AllocOp::create(rewriter, loc, memrefTp, ValueRange{size});
        memref::CopyOp::create(rewriter, loc, field, copied);
        fields.push_back(copied);
      }
      // Reuses the specifier.
      fields.push_back(desc.getSpecifier());
      assert(fields.size() == desc.getNumFields());
      rewriter.replaceOpWithMultiple(op, {fields});
      return success();
    }

    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    // Level sizes equal dimension sizes since the lvl2dim map is an identity
    // map.
    SmallVector<Value> lvlSizesValues;
    createDimSizes(rewriter, loc, resType,
                   flattenValues(adaptor.getDynamicSizes()),
                   /*dimSizesValues=*/lvlSizesValues);

    // Construct allocation for each field.
    Value sizeHint = op.getSizeHint();
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }

private:
  bool enableBufferInitialization;
};

/// Sparse codegen rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }

    Location loc = op.getLoc();
    // Level sizes equal dimension sizes since the lvl2dim map is an identity
    // map.
    SmallVector<Value> lvlSizesValues;
    createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
                   /*dimSizesValues=*/lvlSizesValues);
    // Construct allocation for each field.
    Value sizeHint; // none
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }

private:
  bool enableBufferInitialization;
};

/// Sparse codegen rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto enc = getSparseTensorEncoding(op.getTensor().getType());
    if (!enc)
      return failure();

    // If the user requests not to deallocate sparse tensors, simply erase the
    // operation.
    if (createDeallocs) {
      // Replace the sparse tensor deallocation with field deallocations.
      Location loc = op.getLoc();
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getTensor(),
          cast<RankedTensorType>(op.getTensor().getType()));
      for (auto input : desc.getMemRefFields())
        // Deallocate every buffer used to store the sparse tensor handler.
        memref::DeallocOp::create(rewriter, loc, input);
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  const bool createDeallocs;
};

/// Sparse codegen rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Prepare descriptor.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // Generate optional insertion finalization code.
    if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), desc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for the expand op.
class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorEncoding(op.getTensor().getType()))
      return failure();
    Location loc = op->getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    const auto srcType = getSparseTensorType(op.getTensor());
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());

    // Determine the size for access expansion (always the innermost stored
    // level size).
    const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
    // Generate a memref for `sz` elements of type `t`.
    const auto genAlloc = [&](Type t) {
      const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
      return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
    };
    // Allocate temporary buffers for values/filled-switch and added.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
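    // E.g. (illustrative): expanding the innermost level of a CSR matrix with
    // n columns allocates three n-sized buffers (values, filled, added), even
    // when only a few of their entries are ever written.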
    Value values = genAlloc(eltType);
    Value filled = genAlloc(boolType);
    Value added = genAlloc(idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    linalg::FillOp::create(rewriter, loc,
                           ValueRange{constantZero(rewriter, loc, eltType)},
                           ValueRange{values});
    linalg::FillOp::create(rewriter, loc,
                           ValueRange{constantZero(rewriter, loc, boolType)},
                           ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, added, zero});
    return success();
  }
};

/// Sparse codegen rule for the compress operator.
class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                op.getTensor().getType());
    Value values = llvm::getSingleElement(adaptor.getValues());
    Value filled = llvm::getSingleElement(adaptor.getFilled());
    Value added = llvm::getSingleElement(adaptor.getAdded());
    Value count = llvm::getSingleElement(adaptor.getCount());
    const SparseTensorType dstType(desc.getRankedTensorType());
    Type eltType = dstType.getElementType();

    // If the innermost level is ordered, we need to sort the coordinates
    // in the "added" array prior to applying the compression.
    if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
      SortOp::create(rewriter, loc, count, added, ValueRange{},
                     rewriter.getMultiDimIdentityMap(1),
                     rewriter.getIndexAttr(0),
                     SparseTensorSortKind::HybridQuickSort);
    // While performing the insertions, we also need to reset the elements
    // of the values/filled-switch by only iterating over the set elements,
    // to ensure that the runtime complexity remains proportional to the
    // sparsity of the expanded access pattern.
    //
    // Generate
    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
    //      crd = added[i];
    //      value = values[crd];
    //      new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
    //      values[crd] = 0;
    //      filled[crd] = false;
    //      yield new_memrefs
    //    }
    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
    Value i = loop.getInductionVar();

    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
    SmallVector<Type> flatSpTensorTps = llvm::map_to_vector(
        desc.getFields(), [](Value v) { return v.getType(); });
    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
    params.push_back(crd);
    params.push_back(value);
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
    scf::YieldOp::create(rewriter, loc, insertRet);

    rewriter.setInsertionPointAfter(loop);
    // Deallocate the buffers on exit of the full loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    memref::DeallocOp::create(rewriter, loc, values);
    memref::DeallocOp::create(rewriter, loc, filled);
    memref::DeallocOp::create(rewriter, loc, added);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
    return success();
  }
};

/// Sparse codegen rule for the insert operator.
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getDest());
    if (!stt.hasEncoding())
      return failure();
    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");

    Location loc = op.getLoc();
    auto desc =
        getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
    TypeRange flatSpTensorTps = desc.getFields().getTypes();
    SmallVector<Value> params = llvm::to_vector(desc.getFields());
    SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
    params.append(flatIndices.begin(), flatIndices.end());
    params.push_back(llvm::getSingleElement(adaptor.getScalar()));
    SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {ret});
    return success();
  }
};

/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
  using OpAdaptor = ToPositionsOp::Adaptor;
  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested position access with the corresponding field.
    // The view is restricted to the actual size to ensure clients of this
    // operation truly observe the size, not the capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getPosMemRef(lvl);
    auto size = desc.getPosMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for accessing the coordinates arrays.
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpAdaptor = ToCoordinatesOp::Adaptor;
  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with the corresponding field.
    // The view is restricted to the actual size to ensure clients of this
    // operation truly observe the size, not the capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
    if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
      auto size = desc.getCrdMemSize(rewriter, loc, lvl);
      mem = genSliceToSize(rewriter, loc, mem, size);
    }
    rewriter.replaceOp(op, mem);
    return success();
  }
};

/// Sparse codegen rule for accessing the linear coordinates buffer.
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with the corresponding field.
    // The view is restricted to the actual size to ensure clients of this
    // operation truly observe the size, not the capacity!
    Location loc = op.getLoc();
    Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getAOSMemRef();
    auto size = desc.getCrdMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpAdaptor = ToValuesOp::Adaptor;
  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested values access with the corresponding field.
    // The view is restricted to the actual size to ensure clients of this
    // operation truly observe the size, not the capacity!
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getValMemRef();
    auto size = desc.getValMemSize(rewriter, loc);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for the convert operator.
class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
    SparseTensorEncodingAttr encSrc =
        getSparseTensorEncoding(op.getSource().getType());

    // If either the source or the destination lacks a valid sparse tensor
    // encoding, we should fail to legalize. This should be handled by
    // another set of passes before reaching here.
    if (!encSrc || !encDst)
      return failure();

    // The output tensor cannot be a slice; those cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
    // Different encodings (except for different bitwidths) should be handled
    // by rewriting. We also need further rewrites if the input tensor is a
    // slice.
    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
        encSrc.isSlice()) {
      return failure();
    }

    Type retElemTp = op.getResult().getType().getElementType();
    Type srcElemTp = op.getSource().getType().getElementType();
    // Fold the trivial cases.
    if (retElemTp == srcElemTp && encDst == encSrc) {
      rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
      return success();
    }
    //
    // Do element-wise type conversion without using InsertOp.
    //
    // for each memref in srcTensor:
    //   dst = memref.alloc
    //   if srcMemRefType != dstMemRefType:
    //     for every dst[i] = cast(src[i])
    //   else:
    //     dst = memref.copy(src)
    //
    Location loc = op.getLoc();
    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                                op.getSource().getType());
    SmallVector<Value> fields;
    foreachFieldAndTypeInSparseTensor(
        SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
        [&rewriter, &fields, srcDesc,
         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
              LevelType /*lt*/) -> bool {
          // Simply reuses the storage specifier as it is an SSA value.
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(srcDesc.getSpecifier());
          } else {
            // Allocates new memrefs.
            Value srcMem = srcDesc.getMemRefField(fIdx);
            // TODO: We can instead use the actual memSize in the specifier;
            // that would require a subViewOp to avoid overflow when copying
            // values.
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
            auto dstMem = memref::AllocOp::create(rewriter, loc,
                                                  cast<MemRefType>(fTp), sz);
            if (fTp != srcMem.getType()) {
              // Converts the element type.
              scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                    ValueRange ivs) {
                    Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
                    Value casted = genCast(builder, loc, v,
                                           dstMem.getType().getElementType());
                    memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
                  });
            } else {
              // TODO: We can even reuse the same memref for the new tensor,
              // but that requires a `ref-counting` based memory management
              // for shared memrefs between multiple sparse tensors.
              memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
            }
            fields.push_back(dstMem);
          }
          return true;
        });

    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }
};

class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
    // TODO: We should check these in ExtractSliceOp::verify.
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());

    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
                                                op.getSource().getType());

    auto newSpec = StorageSpecifierInitOp::create(
        rewriter, loc, StorageSpecifierType::get(ctx, dstEnc),
        desc.getSpecifier());
    desc.setSpecifier(newSpec);

    // Fills in slice information.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      Dimension dim = idx;

      Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
      Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
      Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
      // TODO: We could probably only set the dynamic value here. But it would
      // require us to fill the hole when casting a static slice to a dynamic
      // slice.
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
                             dim, offsetV);

      // FIXME: we need to distinguish level sizes from dimension sizes for
      // slices here. Maybe we should store slice level sizes in a different
      // array instead of reusing it.
      assert(srcEnc.isIdentity());
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
                             sizeV);
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
                             dim, strideV);
    }

    // NOTE: we cannot generate tuples directly from the descriptor here, as
    // the descriptor is holding the original type, yet we want the slice type
    // here (they share every memref but with an updated specifier).
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query memSizes for the actually stored values.
    // FIXME: the nse value computed in this way might be wrong when there is
    // any "loose_compressed" level.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
    return success();
  }
};

struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> fields;

    foreachFieldAndTypeInSparseTensor(
        stt,
        [&rewriter, &fields, &op, &stt,
         loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
              Level /*lvl*/, LevelType lt) -> bool {
          assert(fields.size() == fIdx);
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          } else {
            // Else simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];
            // TODO: handle batch.
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
            if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
              // Flattens the buffer to batchLvlRank.
              auto reassoc = getReassociationForFlattening(
                  mem.getType(), stt.getBatchLvlRank());
              mem = memref::CastOp::create(
                  rewriter, loc, fType,
                  memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
            } else {
              mem = memref::CastOp::create(rewriter, loc, fType, mem);
            }
            fields.push_back(mem);
          }
          return true;
        });

    MutSparseTensorDescriptor desc(stt, fields);
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
    Value posBack = c0; // index to the last value in the position array
    Value memSize = c1; // memory size for current array

    Level trailCOOStart = stt.getAoSCOOStart();
    Level trailCOORank = stt.getLvlRank() - trailCOOStart;
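    // E.g. (illustrative): for a static m x n CSR tensor with nnz stored
    // entries, the loop below sets the positions memory size to m + 1, reads
    // nnz back from the last positions entry, and then sets both the
    // coordinates and values memory sizes to nnz.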
1361 // Sets up SparseTensorSpecifier.
1362 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1363 assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
1364
1365 // Sets up the level size.
1366 auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1367 desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1368 // We use a single AOS array to store the trailing COO, so there is only
1369 // one memory size to set for the entire COO section.
1370 if (lvl > trailCOOStart)
1371 continue;
1372
1373 // Sets up the memory size by reading the last value in position array.
1374 LevelType lt = stt.getLvlType(lvl);
1375 // Simply forwards the position index when this is a dense level.
1376 if (lt.isa<LevelFormat::Dense>()) {
1377 memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
1378 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1379 continue;
1380 }
1381 if (lt.isa<LevelFormat::Batch>()) {
1382 // Skips batch levels as it is not linearized.
1383 // FIXME: this assumes that every batch has the same number of nse, need
1384 // to be generalized to handle varied-size batches.
1385 continue;
1386 }
1387
1388 if (isWithPosLT(lt)) {
1389 assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1390 if (isLooseCompressedLT(lt)) {
1391 memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
1392 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1393 } else {
1394 assert(isCompressedLT(lt));
1395 posBack = memSize;
1396 memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
1397 }
1398 desc.setPosMemSize(rewriter, loc, lvl, memSize);
1399 // The last value in the position array is the memory size for the next level.
1400 // FIXME: this assumes that every batch has the same number of nse; needs
1401 // to be generalized to handle varied-size batches.
1402 SmallVector<Value> batched(stt.getBatchLvlRank(),
1403 constantIndex(rewriter, loc, 0));
1404 batched.push_back(posBack);
1405 memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1406 posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
1407 }
1408 assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1409 // FIXME: This seems unnecessarily complex; can we simplify it?
1410 if (lvl == trailCOOStart) {
1411 Value cooSz = arith::MulIOp::create(
1412 rewriter, loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1413 desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1414 } else {
1415 desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1416 }
1417 }
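// Illustrative trace: for a static 3x4 CSR tensor with positions
// [0, 2, 3, 5], the loop above computes
//   lvl 0 (dense):      memSize = 3 (= lvlSize), posBack = 2 (unused)
//   lvl 1 (compressed): posBack = 3, pos memSize = 4 (= #rows + 1),
//                       memSize = positions[3] = 5 (= nse), crd memSize = 5
// so the value memSize set below is 5, the number of stored entries.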
1418 desc.setValMemSize(rewriter, loc, memSize);
1419
1420 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1421 return success();
1422 }
1423};
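For reference, a schematic use of the operation this pattern lowers; the CSR shapes and operand names here are hypothetical:

// %csr = sparse_tensor.assemble (%pos, %crd), %vals
//      : (tensor<4xindex>, tensor<5xindex>), tensor<5xf64>
//        to tensor<3x4xf64, #CSR>
// The converter forwards %pos, %crd, and %vals as the storage fields and
// synthesizes the storage specifier (level sizes and memSizes) as above.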
1424
1425struct SparseDisassembleOpConverter
1426 : public OpConversionPattern<DisassembleOp> {
1427 using OpConversionPattern::OpConversionPattern;
1428 SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1429 MLIRContext *context)
1430 : OpConversionPattern(typeConverter, context) {}
1431
1432 LogicalResult
1433 matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1434 ConversionPatternRewriter &rewriter) const override {
1435 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1436 op.getTensor().getType());
1437 Location loc = op.getLoc();
1438 SmallVector<Value> retMem;
1439 SmallVector<Value> retLen;
1440 desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1441 &retLen](FieldIndex fid,
1442 SparseTensorFieldKind fKind,
1443 Level lvl, LevelType lt) -> bool {
1444 if (fKind == SparseTensorFieldKind::StorageSpec)
1445 return true;
1446 SparseTensorType stt(desc.getRankedTensorType());
1447 Value sz, src;
1448 TypedValue<BaseMemRefType> dst;
1449 if (fKind == SparseTensorFieldKind::ValMemRef) {
1450 sz = desc.getValMemSize(rewriter, loc);
1451 src = desc.getValMemRef();
1452 dst = genToMemref(rewriter, loc, op.getOutValues());
1453
1454 retMem.push_back(dst);
1455 Type valLenTp = op.getValLen().getType();
1456 retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1457 } else {
1458 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1459 fKind == SparseTensorFieldKind::CrdMemRef);
1460
1461 sz = fKind == SparseTensorFieldKind::PosMemRef
1462 ? desc.getPosMemSize(rewriter, loc, lvl)
1463 : desc.getCrdMemSize(rewriter, loc, lvl);
1464 src = desc.getMemRefField(fid);
1465 dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1466 retMem.push_back(dst);
1467 // Retrieves the corresponding level length type.
1468 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1469 retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1470 }
1471 Value flatOut = dst;
1472 if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1473 auto reassoc =
1474 getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1475 flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
1476 }
1477 Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1478 Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1479 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1480 return true;
1481 });
1482
1483 // Converts MemRefs back to Tensors.
1484 SmallVector<Value> retValues =
1485 llvm::map_to_vector(retMem, [&rewriter, loc](Value v) -> Value {
1486 return bufferization::ToTensorOp::create(
1487 rewriter, loc, memref::getTensorTypeFromMemRefType(v.getType()),
1488 v);
1489 });
1490 // Appends the actual memory length used in each buffer returned.
1491 retValues.append(retLen.begin(), retLen.end());
1492 rewriter.replaceOp(op, retValues);
1493 return success();
1494 }
1495};
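Schematically, the per-field copy above trims each buffer to its in-use prefix; this is a hedged sketch of the net effect, not a verbatim expansion of the generated IR:

// For each non-specifier field:
//   %flat = memref.collapse_shape %dst [...]    // only in the batched case
//   %d    = memref.subview %flat[0] [%sz] [1]   // via genSliceToSize
//   %s    = memref.subview %src[0] [%sz] [1]
//   memref.copy %s to %d
// Callers thus receive their own output buffers with the first %sz
// elements filled, plus the corresponding lengths as scalars or 0-rank
// tensors.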
1496
1497struct SparseNewConverter : public OpConversionPattern<NewOp> {
1498 using OpConversionPattern::OpConversionPattern;
1499 LogicalResult
1500 matchAndRewrite(NewOp op, OpAdaptor adaptor,
1501 ConversionPatternRewriter &rewriter) const override {
1502 Location loc = op.getLoc();
1503 const auto dstTp = getSparseTensorType(op.getResult());
1504 // Creating COO with NewOp is handled by direct IR codegen. All other cases
1505 // are handled by rewriting.
1506 if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1507 return failure();
1508
1509 // Implement as follows:
1510 // %reader = @createCheckedSparseTensorReader(%filename)
1511 // %nse = @getSparseTensorReaderNSE(%reader)
1512 // %coo = bufferization.alloc_tensor an ordered COO with
1513 // dst dim ordering, size_hint = %nse
1514 // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1515 // %values = sparse_tensor.values(%coo)
1516 // %isSorted = @getSparseTensorReaderReadToBuffers(%coordinates, %values)
1517 // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1518 // update storage specifier
1519 // @delSparseTensorReader(%reader)
1520 SmallVector<Value> dimSizesValues;
1521 Value dimSizesBuffer;
1522 Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1523 dimSizesValues, dimSizesBuffer);
1524
1525 // Get the number of stored entries.
1526 const Type indexTp = rewriter.getIndexType();
1527 Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1528 {indexTp}, {reader}, EmitCInterface::Off)
1529 .getResult(0);
1530
1531 // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1532 SmallVector<Value> lvlSizesValues;
1533 Value dim2lvlBuffer;
1534 Value lvl2dimBuffer;
1535 genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1536 lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1537
1538 // Construct allocation for each field.
1539 Value sizeHint = nse;
1540 SmallVector<Value> fields;
1541 createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1542 lvlSizesValues, fields);
1543
1544 // Read the COO tensor data.
1545 MutSparseTensorDescriptor desc(dstTp, fields);
1546 Value xs = desc.getAOSMemRef();
1547 Value ys = desc.getValMemRef();
1548 const Type boolTp = rewriter.getIntegerType(1);
1549 const Type elemTp = dstTp.getElementType();
1550 const Type crdTp = dstTp.getCrdType();
1551 SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1552 overheadTypeFunctionSuffix(crdTp),
1553 primaryTypeFunctionSuffix(elemTp)};
1554 Value isSorted =
1555 createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1556 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1557 EmitCInterface::On)
1558 .getResult(0);
1559
1560 // If the destination tensor is a sorted COO, we need to sort the COO tensor
1561 // data if the input elements aren't sorted yet.
1562 const Level lvlRank = dstTp.getLvlRank();
1563 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1564 Value kFalse = constantI1(rewriter, loc, false);
1565 Value notSorted = arith::CmpIOp::create(
1566 rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1567 scf::IfOp ifOp =
1568 scf::IfOp::create(rewriter, loc, notSorted, /*else*/ false);
1569 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1570 auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1571 SortOp::create(rewriter, loc, nse, xs, ValueRange{ys}, xPerm,
1572 rewriter.getIndexAttr(0),
1573 SparseTensorSortKind::HybridQuickSort);
1574 rewriter.setInsertionPointAfter(ifOp);
1575 }
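// Note: the identity permutation built above makes the sort compare full
// coordinate tuples lexicographically (for lvlRank = 2: by d0, then d1),
// which restores the canonical ordering of an ordered COO tensor.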
1576
1577 // Set PosMemRef0[1] = nse.
1578 const Value c1 = constantIndex(rewriter, loc, 1);
1579 const Value posMemref0 = desc.getPosMemRef(0);
1580 const Type posTp = dstTp.getPosType();
1581 const Value posNse = genCast(rewriter, loc, nse, posTp);
1582 memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
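// E.g., with nse = 5 the level-0 positions buffer ends up as [0, 5]: a
// single compressed region covering all stored entries of the COO.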
1583
1584 // Update storage specifier.
1585 Value coordinatesSize = arith::MulIOp::create(
1586 rewriter, loc, nse, constantIndex(rewriter, loc, lvlRank));
1587 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1588 coordinatesSize);
1589 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1590 std::nullopt, nse);
1591
1592 // Release the sparse tensor reader.
1593 createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1594 EmitCInterface::Off);
1595
1596 // Replace operation with resulting memrefs.
1597 rewriter.replaceOpWithMultiple(op, {fields});
1598 return success();
1599 }
1600};
1601
1602struct SparseHasRuntimeLibraryConverter
1603 : public OpConversionPattern<HasRuntimeLibraryOp> {
1604 using OpConversionPattern::OpConversionPattern;
1605 LogicalResult
1606 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1607 ConversionPatternRewriter &rewriter) const override {
1608 auto i1Type = rewriter.getI1Type();
1609 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1610 op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1611 return success();
1612 }
1613};
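In other words, the query lowers to a compile-time answer:

// After conversion, sparse_tensor.has_runtime_library yields
//   %false = arith.constant false
// since this codegen path does not depend on the runtime support library.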
1614
1615} // namespace
1616
1617//===----------------------------------------------------------------------===//
1618// Public method for populating conversion rules.
1619//===----------------------------------------------------------------------===//
1620
1621/// Populates the given patterns list with conversion rules required for
1622/// the sparsification of linear algebra operations.
1623void mlir::populateSparseTensorCodegenPatterns(
1624 const TypeConverter &typeConverter, RewritePatternSet &patterns,
1625 bool createSparseDeallocs, bool enableBufferInitialization) {
1626 patterns.add<
1627 SparseAssembleOpConverter, SparseDisassembleOpConverter,
1628 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1629 SparseCastConverter, SparseExtractSliceConverter,
1630 SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1631 SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1632 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1633 StorageSpecifierKind::DimOffset>,
1634 SparseSliceGetterOpConverter<ToSliceStrideOp,
1635 StorageSpecifierKind::DimStride>,
1636 SparseToPositionsConverter, SparseToCoordinatesConverter,
1637 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1638 SparseConvertConverter, SparseNewConverter,
1639 SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1640 typeConverter, patterns.getContext());
1641 patterns.add<SparseTensorDeallocConverter>(
1642 typeConverter, patterns.getContext(), createSparseDeallocs);
1643 patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1644 typeConverter, patterns.getContext(), enableBufferInitialization);
1645}
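As a usage note, the following sketch shows how a pass might drive these rules. It is hypothetical: the legality setup is abbreviated, and it assumes a TypeConverter that expands each sparse tensor value into its storage fields (as the upstream codegen pass provides).

// Hypothetical driver sketch, not part of this file.
static LogicalResult runSparseCodegenSketch(ModuleOp module,
                                            TypeConverter &converter) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  populateSparseTensorCodegenPatterns(converter, patterns,
                                      /*createSparseDeallocs=*/true,
                                      /*enableBufferInitialization=*/false);
  ConversionTarget target(*ctx);
  // A real driver refines legality per op; e.g., sparse_tensor.sort_coo
  // and sparse_tensor.push_back emitted by these patterns stay legal and
  // are lowered by later passes.
  target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
                         memref::MemRefDialect, scf::SCFDialect>();
  return applyPartialConversion(module, target, std::move(patterns));
}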