MLIR 23.0.0git
SparseTensorDialect.cpp
Go to the documentation of this file.
1//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
17
22#include "mlir/IR/Builders.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32
33// Forward declarations, following custom print/parsing methods are referenced
34// by the generated code for SparseTensorTypes.td.
35static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
40
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
43
44using namespace mlir;
45using namespace mlir::sparse_tensor;
46
47// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
48// well.
49namespace mlir::sparse_tensor {
50static llvm::hash_code hash_value(LevelType lt) {
51 return llvm::hash_value(static_cast<uint64_t>(lt));
52}
53} // namespace mlir::sparse_tensor
54
55//===----------------------------------------------------------------------===//
56// Local Convenience Methods.
57//===----------------------------------------------------------------------===//
58
/// Returns true iff `bitWidth` is an admissible overhead bitwidth:
/// 0 (meaning the native `index` width) or one of 8, 16, 32, 64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 ||
         bitWidth == 32 || bitWidth == 64;
}
71
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  // Computes the memref shape shared by the sparse storage buffers
  // (positions/coordinates/values): [batch_0, ..., batch_k, ?].
  assert(enc);
  // With only encoding, we can not determine the static shape for leading
  // batch levels, we therefore return a dynamic shape memref instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimension.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // Another dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}
91
92//===----------------------------------------------------------------------===//
93// SparseTensorDialect StorageLayout.
94//===----------------------------------------------------------------------===//
95
// Sentinels for fields that are not tied to a concrete storage level
// (e.g. the values array and the storage specifier).
static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
// Data fields (pos/crd/val memrefs) start at index 0 in the storage layout.
static constexpr FieldIndex kDataFieldStartingIdx = 0;
99
                                             LevelType)>
                      callback) const {
  // Visits every storage field of this layout in order: per-level pos/crd
  // memrefs (with an AoS COO segment collapsed into a single crd memref),
  // then the values memref, then the trailing storage-specifier metadata.
  // Iteration stops early as soon as `callback` returns false.
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();

  // Segments are consumed front-to-back as levels are walked.
  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memref. Skips the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
    return;
}
146
                                             LevelType)>
                             callback) {
  // Like StorageLayout::foreachField, but additionally supplies the concrete
  // MLIR type of each field (specifier type or the appropriate memref type).
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  // Dispatch on the field kind, forwarding the precomputed field type.
  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}
182
  // Counts all storage fields (including the trailing storage specifier)
  // by running the field iterator with a callback that never stops early.
  unsigned numFields = 0;
                     LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}
192
  // Counts only the data fields (pos/crd/val memrefs), i.e. everything at or
  // after kDataFieldStartingIdx, excluding the final storage specifier.
  unsigned numFields = 0; // one value memref
                     LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}
205
std::pair<FieldIndex, unsigned>
                               std::optional<Level> lvl) const {
  unsigned stride = 1;
    assert(lvl.has_value());
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      // All levels in an AoS COO segment share one coordinate memref: the
      // canonical field is the one at the segment start, and the stride is
      // the segment length.
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  // Scan the layout for the first field matching the requested kind/level;
  // value memrefs match regardless of level.
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
234
235//===----------------------------------------------------------------------===//
236// SparseTensorDialect Attribute Methods.
237//===----------------------------------------------------------------------===//
238
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(static_cast<uint64_t>(v));
242}
243
244std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
245 return getStatic(getOffset());
246}
247
248std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
249 return getStatic(getStride());
250}
251
252std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
253 return getStatic(getSize());
254}
255
256bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
257 return isDynamic(getOffset()) && isDynamic(getStride()) &&
258 isDynamic(getSize());
259}
260
261std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ? "?" : std::to_string(v);
263}
264
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
266 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
267 os << '(';
268 os << getStaticString(getOffset());
269 os << ", ";
270 os << getStaticString(getSize());
271 os << ", ";
272 os << getStaticString(getStride());
273 os << ')';
274}
275
276void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
277 print(printer.getStream());
278}
279
                                             AsmParser &parser) {
  // Parses one slice component: either a non-negative integer literal or `?`
  // (dynamic). On success `result` holds the value (or kDynamic).
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      // NOTE(review): the message says "positive" but the check admits zero;
      // verify() later rejects zero for size/stride only — confirm intended.
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, and '?' which represented dynamic slice
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}
297
298Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
300
301 if (failed(parser.parseLParen()) ||
302 failed(parseOptionalStaticSlice(offset, parser)) ||
303 failed(parser.parseComma()) ||
304 failed(parseOptionalStaticSlice(size, parser)) ||
305 failed(parser.parseComma()) ||
306 failed(parseOptionalStaticSlice(stride, parser)) ||
307 failed(parser.parseRParen()))
308 return {};
309
310 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
311 offset, size, stride);
312}
313
314LogicalResult
315SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(offset) && offset < 0)
318 return emitError() << "expect non-negative value or ? for slice offset";
319 if (!isDynamic(size) && size <= 0)
320 return emitError() << "expect positive value or ? for slice size";
321 if (!isDynamic(stride) && stride <= 0)
322 return emitError() << "expect positive value or ? for slice stride";
323 return success();
324}
325
326SparseTensorEncodingAttr
327SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
328 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
329 return SparseTensorEncodingAttr::get(
330 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
331 getCrdWidth(), getExplicitVal(), getImplicitVal());
332}
333
334SparseTensorEncodingAttr
335SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
336 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
337}
338
339SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
340 return withDimToLvl(AffineMap());
341}
342
343SparseTensorEncodingAttr
344SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
345 unsigned crdWidth) const {
346 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
347 return SparseTensorEncodingAttr::get(
348 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
349 crdWidth, getExplicitVal(), getImplicitVal());
350}
351
352SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
353 return withBitWidths(0, 0);
354}
355
356SparseTensorEncodingAttr
357SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
358 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
359 return SparseTensorEncodingAttr::get(
360 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
361 getCrdWidth(), explicitVal, getImplicitVal());
362}
363
364SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
365 return withExplicitVal(Attribute());
366}
367
368SparseTensorEncodingAttr
369SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
370 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
371 return SparseTensorEncodingAttr::get(
372 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
373 getCrdWidth(), getExplicitVal(), implicitVal);
374}
375
376SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
377 return withImplicitVal(Attribute());
378}
379
380SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
381 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
382 return SparseTensorEncodingAttr::get(
383 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
384 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
385}
386
387SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
388 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
389}
390
391uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
392 ArrayRef<LevelType> lvlTypes = getLvlTypes();
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
395}
396
397bool SparseTensorEncodingAttr::isAllDense() const {
398 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
399}
400
401bool SparseTensorEncodingAttr::isAllOrdered() const {
402 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
403}
404
405Type SparseTensorEncodingAttr::getCrdElemType() const {
406 if (!getImpl())
407 return nullptr;
408 if (getCrdWidth())
409 return IntegerType::get(getContext(), getCrdWidth());
410 return IndexType::get(getContext());
411}
412
413Type SparseTensorEncodingAttr::getPosElemType() const {
414 if (!getImpl())
415 return nullptr;
416 if (getPosWidth())
417 return IntegerType::get(getContext(), getPosWidth());
418 return IndexType::get(getContext());
419}
420
421MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
422 std::optional<ArrayRef<int64_t>> dimShape) const {
423 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
424 return MemRefType::get(shape, getCrdElemType());
425}
426
427MemRefType SparseTensorEncodingAttr::getPosMemRefType(
428 std::optional<ArrayRef<int64_t>> dimShape) const {
429 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
430 return MemRefType::get(shape, getPosElemType());
431}
432
433bool SparseTensorEncodingAttr::isIdentity() const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
435}
436
437bool SparseTensorEncodingAttr::isPermutation() const {
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
439}
440
441Dimension SparseTensorEncodingAttr::getDimRank() const {
442 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
445}
446
447Level SparseTensorEncodingAttr::getLvlRank() const {
448 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
450}
451
452LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
453 if (!getImpl())
454 return LevelFormat::Batch;
455 assert(l < getLvlRank() && "Level is out of bounds");
456 return getLvlTypes()[l];
457}
458
459bool SparseTensorEncodingAttr::isSlice() const {
460 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
462}
463
464SparseTensorDimSliceAttr
465SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
466 assert(isSlice() && "Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() && "Dimension is out of bounds");
469 return dimSlices[dim];
470}
471
472std::optional<uint64_t>
473SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
474 return getDimSlice(dim).getStaticOffset();
475}
476
477std::optional<uint64_t>
478SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
479 return getDimSlice(dim).getStaticStride();
480}
481
482std::optional<uint64_t>
483SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
484 return getStaticDimSliceOffset(toDim(*this, lvl));
485}
486
487std::optional<uint64_t>
488SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
489 return getStaticDimSliceStride(toDim(*this, lvl));
490}
491
492SmallVector<int64_t>
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir) const {
495 if (isIdentity())
496 return SmallVector<int64_t>(srcShape);
497
498 SmallVector<int64_t> ret;
499 unsigned rank =
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
501 ret.reserve(rank);
502
503 if (isPermutation()) {
504 for (unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
506 : toLvl(*this, r);
507 ret.push_back(srcShape[trans]);
508 }
509 return ret;
510 }
511
512 // Handle non-permutation maps.
513 AffineMap transMap =
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
515
516 SmallVector<AffineExpr> dimRep;
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (ShapedType::isStatic(sz)) {
520 // Push back the max coordinate for the given dimension/level size.
521 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
522 } else {
523 // A dynamic size, use a AffineDimExpr to symbolize the value.
524 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
525 }
526 };
527
528 // The number of symbols information is included inside the `dimToLvl` map
529 // during parsing. Here, we're extracting it to be used when simplifying the
530 // affine expression.
531 unsigned numSymbols = getDimToLvl().getNumSymbols();
532
533 for (AffineExpr exp : transMap.getResults()) {
534 // Do constant propagation on the affine map.
535 AffineExpr evalExp = simplifyAffineExpr(exp.replaceDims(dimRep),
536 srcShape.size(), numSymbols);
537 // use llvm namespace here to avoid ambiguity
538 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
539 ret.push_back(c.getValue() + 1);
540 } else {
541 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
542 mod && mod.getKind() == AffineExprKind::Mod) {
543 // We can still infer a static bound for expressions in form
544 // "d % constant" since d % constant \in [0, constant).
545 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
546 ret.push_back(bound.getValue());
547 continue;
548 }
549 }
550 ret.push_back(ShapedType::kDynamic);
551 }
552 }
553 assert(ret.size() == rank);
554 return ret;
555}
556
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  // Emits a CrdTranslateOp that converts coordinates between dim- and
  // lvl-space; a null encoding is the identity, so return the input as-is.
  if (!getImpl())
    return crds;

  // One index-typed result per target-space coordinate.
  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp =
      CrdTranslateOp::create(builder, loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}
571
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Parses the `<{...}>` body of the encoding attribute. Admissible keys are
  // map, posWidth, crdWidth, explicitVal, and implicitVal, comma separated;
  // only the last entry may omit the trailing comma.
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after keys
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      // Only numeric (float/integer/complex) attributes are admissible.
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      // Only numeric (float/integer/complex) attributes are admissible.
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    // Derive the inverse mapping when none was provided explicitly.
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}
718
719void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
720 auto map = static_cast<AffineMap>(getDimToLvl());
721 // Empty affine map indicates identity map
722 if (!map)
723 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
724 printer << "<{ map = ";
725 printSymbols(map, printer);
726 printer << '(';
727 printDimensions(map, printer, getDimSlices());
728 printer << ") -> (";
729 printLevels(map, printer, getLvlTypes());
730 printer << ')';
731 // Print remaining members only for non-default values.
732 if (getPosWidth())
733 printer << ", posWidth = " << getPosWidth();
734 if (getCrdWidth())
735 printer << ", crdWidth = " << getCrdWidth();
736 if (getExplicitVal()) {
737 printer << ", explicitVal = " << getExplicitVal();
738 }
739 if (getImplicitVal())
740 printer << ", implicitVal = " << getImplicitVal();
741 printer << " }>";
742}
743
744void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
745 AsmPrinter &printer) const {
746 if (map.getNumSymbols() == 0)
747 return;
748 printer << '[';
749 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
750 printer << 's' << i << ", ";
751 if (map.getNumSymbols() >= 1)
752 printer << 's' << map.getNumSymbols() - 1;
753 printer << ']';
754}
755
756void SparseTensorEncodingAttr::printDimensions(
757 AffineMap &map, AsmPrinter &printer,
758 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
759 if (!dimSlices.empty()) {
760 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
761 printer << 'd' << i << " : " << dimSlices[i] << ", ";
762 if (map.getNumDims() >= 1) {
763 printer << 'd' << map.getNumDims() - 1 << " : "
764 << dimSlices[map.getNumDims() - 1];
765 }
766 } else {
767 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
768 printer << 'd' << i << ", ";
769 if (map.getNumDims() >= 1)
770 printer << 'd' << map.getNumDims() - 1;
771 }
772}
773
774void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
775 ArrayRef<LevelType> lvlTypes) const {
776 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
777 map.getResult(i).print(printer.getStream());
778 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
779 }
780 if (map.getNumResults() >= 1) {
781 auto lastIndex = map.getNumResults() - 1;
782 map.getResult(lastIndex).print(printer.getStream());
783 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
784 }
785}
786
LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  // Structural verification of all encoding components: bitwidths, COO
  // segments, batch/SoA/n:m placement rules, rank coherence, and slices.
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  auto *it = llvm::find_if(lvlTypes, isSingletonLT);
  while (it != lvlTypes.end()) {
    if (it == lvlTypes.begin() ||
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    if (!std::all_of(it, curCOOEnd, isSingletonLT))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    // Advance to the next COO segment (if any).
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  // Batch levels, when present, must be a leading prefix.
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // SoA property can only be applied on singleton level.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // Dense levels cannot follow a non-unique level. The iteration model for
  // dense levels requires exactly one parent position to linearize into a
  // contiguous range, but a non-unique parent provides two cursor values
  // (segment start and end), which the dense level cannot handle.
  for (auto [i, lt] : llvm::drop_begin(llvm::enumerate(lvlTypes))) {
    if (isDenseLT(lt) && !isUniqueLT(lvlTypes[i - 1]))
      return emitError() << "dense level cannot follow a non-unique level";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      // All non-zero block sizes must agree, and must equal m.
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coeffiencts of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}
912
913LogicalResult SparseTensorEncodingAttr::verifyEncoding(
914 ArrayRef<Size> dimShape, Type elementType,
915 function_ref<InFlightDiagnostic()> emitError) const {
916 // Check structural integrity. In particular, this ensures that the
917 // level-rank is coherent across all the fields.
918 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
919 getPosWidth(), getCrdWidth(), getExplicitVal(),
920 getImplicitVal(), getDimSlices())))
921 return failure();
922 // Check integrity with tensor type specifics. In particular, we
923 // need only check that the dimension-rank of the tensor agrees with
924 // the dimension-rank of the encoding.
925 const Dimension dimRank = dimShape.size();
926 if (dimRank == 0)
927 return emitError() << "expected non-scalar sparse tensor";
928 if (getDimRank() != dimRank)
929 return emitError()
930 << "dimension-rank mismatch between encoding and tensor shape: "
931 << getDimRank() << " != " << dimRank;
932 if (auto expVal = getExplicitVal()) {
933 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
934 if (attrType != elementType) {
935 return emitError() << "explicit value type mismatch between encoding and "
936 << "tensor element type: " << attrType
937 << " != " << elementType;
938 }
939 }
940 if (auto impVal = getImplicitVal()) {
941 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
942 if (attrType != elementType) {
943 return emitError() << "implicit value type mismatch between encoding and "
944 << "tensor element type: " << attrType
945 << " != " << elementType;
946 }
947 // Currently, we only support zero as the implicit value.
948 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
949 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
950 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
951 if ((impFVal && impFVal.getValue().isNonZero()) ||
952 (impIntVal && !impIntVal.getValue().isZero()) ||
953 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
954 impComplexVal.getReal().isNonZero()))) {
955 return emitError() << "implicit value must be zero";
956 }
957 }
958 return success();
959}
960
961Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
962 SmallVector<COOSegment> coo = getCOOSegments();
963 assert(coo.size() == 1 || coo.empty());
964 if (!coo.empty() && coo.front().isAoS()) {
965 return coo.front().lvlRange.first;
966 }
967 return getLvlRank();
968}
969
SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  // A COO segment needs at least two levels.
  if (getLvlRank() <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < getLvlRank()) {
    auto lt = lts[l];
    // NOTE(review): an `if (...) {` guard on `lt` appears to be missing here
    // in this excerpt (the dangling `} else {` below has no visible match) —
    // verify against upstream.
      auto cur = lts.begin() + l;
      // The segment spans the maximal run of singleton levels that follows.
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes, for now we faithfully assume that all
        // consecutive singleton levels have the same storage format as verified
        // STEA.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
1001
1002//===----------------------------------------------------------------------===//
1003// SparseTensorType Methods.
1004//===----------------------------------------------------------------------===//
1005
                                     bool isUnique) const {
  // NOTE(review): the opening signature line (function name and the
  // `startLvl` parameter) is missing from this excerpt.
  // A COO region starts at a (loose) compressed level followed only by
  // singleton levels up to the level-rank.
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}
1020
RankedTensorType
// NOTE(review): the line naming this method (and its parameter list,
// presumably `SparseTensorType::getCOOType(bool ordered) const {`) is
// missing from this excerpt.
  // Builds a COO-encoded tensor type with the same shape/element type:
  // one compressed level followed by singleton levels.
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at beginning (unless this is
  // also the last level, then it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends by a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
  // NOTE(review): the remaining argument line of this call (presumably the
  // crd-width and explicit/implicit values) is missing from this excerpt.
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}
1041
1042//===----------------------------------------------------------------------===//
1043// Convenience Methods.
1044//===----------------------------------------------------------------------===//
1045
SparseTensorEncodingAttr
// NOTE(review): the signature line (presumably
// `mlir::sparse_tensor::getSparseTensorEncoding(Type type) {`) is missing
// from this excerpt.
  // Extracts the sparse encoding from a ranked tensor or a storage
  // specifier type; returns null for any other type.
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}
1054
// NOTE(review): the first signature line (function name and the `dimToLvl`
// parameter) is missing from this excerpt.
                                 MLIRContext *context) {
  // Attempts to infer the inverse (lvlToDim) of a dimToLvl mapping.
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    // A plain permutation is inverted directly.
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    // A block-sparsity map (floordiv/mod pattern) has a dedicated inverse.
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}
1069
// NOTE(review): the first signature line (function name and the `dimToLvl`
// parameter) is missing from this excerpt.
                                 MLIRContext *context) {
  // Builds the lvlToDim map that inverts a block-sparsity dimToLvl map.
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      // An unblocked dimension maps straight back to a level variable.
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
1122
// NOTE(review): the signature line (presumably
// `SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap
// dimToLvl) {`) is missing from this excerpt.
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        // The RHS constant of each mod expression is that dimension's
        // block size.
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      // An unblocked dimension is encoded as block size 0.
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
1139
// NOTE(review): the signature line (presumably
// `bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {`) is
// missing from this excerpt.
  // Recognizes block-sparsity maps: each blocked dimension appears as one
  // `dim floordiv c` and a later `dim mod c` with the same constant c.
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coeffientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        auto [it, inserted] = coeffientMap.try_emplace(pos);
        if (!inserted)
          return false;
        // Record coefficient of the floordiv.
        it->second = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        auto it = coeffientMap.find(pos);
        if (it == coeffientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (conOp.getValue() != it->second)
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect dim to be unset.
      if (!coeffientMap.try_emplace(pos, 0).second)
        return false;
    } else {
      return false;
    }
  }
  // Valid only if at least one floordiv/mod pair was matched.
  return hasBlock;
}
1184
// NOTE(review): the signature line (presumably
// `bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation
// *op) {`) is missing from this excerpt.
  // True iff the value is a sparse tensor whose dimToLvl is not identity.
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}
1194
1195Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1196 if (enc) {
1197 assert(enc.isPermutation() && "Non permutation map not supported");
1198 if (const auto dimToLvl = enc.getDimToLvl())
1199 return dimToLvl.getDimPosition(l);
1200 }
1201 return l;
1202}
1203
1204Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1205 if (enc) {
1206 assert(enc.isPermutation() && "Non permutation map not supported");
1207 if (const auto lvlToDim = enc.getLvlToDim())
1208 return lvlToDim.getDimPosition(d);
1209 }
1210 return d;
1211}
1212
1213/// We normalized sparse tensor encoding attribute by always using
1214/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1215/// as other variants) lead to the same storage specifier type, and stripping
1216/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  // NOTE(review): the declaration of `lts` (presumably
  // `SmallVector<LevelType> lts;`) is missing from this excerpt.
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // value for different bitwidth, it also avoids casting between index and
      // integer (returned by DimOp)
      0, 0,
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
}
1236
StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  // Normalize the encoding first so equivalent storage layouts unique to
  // the same specifier type.
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}
1241
StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {
  // Same as get(), but reports failures through `emitError`.
  // NOTE(review): the final argument line of this call (presumably
  // `getNormalizedEncodingForSpecifier(encoding));`) is missing from this
  // excerpt.
  return Base::getChecked(emitError, ctx,
}
1249
1250//===----------------------------------------------------------------------===//
1251// SparseTensorDialect Operations.
1252//===----------------------------------------------------------------------===//
1253
1254static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1255 return success(lvl < getSparseTensorType(tensor).getLvlRank());
1256}
1257
1258static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1259 const Type etp = getMemRefType(mem).getElementType();
1260 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1261}
1262
// Shared verifier for storage-specifier get/set ops: checks the kind/level
// combination against the specifier's encoding.
static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    // NOTE(review): the remaining parameter line (presumably
    // `StorageSpecifierType md, Operation *op) {`) is missing from this
    // excerpt; the body references both `md` and `op`.
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    // All non-value queries are per-level and need a valid level argument.
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}
1293
// Maps a sparse-tensor storage field kind to its element type.
// NOTE(review): the enclosing signature (presumably `static Type
// getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {`)
// and the `case SparseTensorFieldKind::...:` labels are missing from this
// excerpt.
  switch (kind) {
    return stt.getCrdType();
    return stt.getPosType();
    return stt.getElementType();
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}
1307
// Shared verifier for AssembleOp/DisassembleOp: checks the value buffer and
// the per-level buffers against the sparse tensor's storage layout.
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only supports trailing COO for now, must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in shape of <? x rank>
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    // NOTE(review): a guard line appears to be missing here in this excerpt
    // (presumably skipping non-data fields before the unconditional
    // `return true;`) — verify against upstream.
    return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}
1363
1364LogicalResult AssembleOp::verify() {
1365 RankedTensorType valuesTp = getValues().getType();
1366 const auto lvlsTp = getLevels().getTypes();
1367 const auto resTp = getSparseTensorType(getResult());
1368 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1369}
1370
1371LogicalResult DisassembleOp::verify() {
1372 if (getOutValues().getType() != getRetValues().getType())
1373 return emitError("output values and return value type mismatch");
1374
1375 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1376 if (ot.getType() != rt.getType())
1377 return emitError("output levels and return levels type mismatch");
1378
1379 RankedTensorType valuesTp = getRetValues().getType();
1380 const auto lvlsTp = getRetLevels().getTypes();
1381 const auto srcTp = getSparseTensorType(getTensor());
1382 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1383}
1384
1385LogicalResult ConvertOp::verify() {
1386 RankedTensorType tp1 = getSource().getType();
1387 RankedTensorType tp2 = getDest().getType();
1388 if (tp1.getRank() != tp2.getRank())
1389 return emitError("unexpected conversion mismatch in rank");
1390 auto dstEnc =
1391 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1392 if (dstEnc && dstEnc.isSlice())
1393 return emitError("cannot convert to a sparse tensor slice");
1394
1395 auto shape1 = tp1.getShape();
1396 auto shape2 = tp2.getShape();
1397 // Accept size matches between the source and the destination type
1398 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1399 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1400 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1401 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1402 return emitError("unexpected conversion mismatch in dimension ") << d;
1403 return success();
1404}
1405
1406OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1407 if (getType() == getSource().getType())
1408 return getSource();
1409 return {};
1410}
1411
1412bool ConvertOp::needsExtraSort() {
1413 SparseTensorType srcStt = getSparseTensorType(getSource());
1414 SparseTensorType dstStt = getSparseTensorType(getDest());
1415
1416 // We do not need an extra sort when returning unordered sparse tensors or
1417 // dense tensor since dense tensor support random access.
1418 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1419 return false;
1420
1421 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1422 srcStt.hasSameDimToLvl(dstStt)) {
1423 return false;
1424 }
1425
1426 // Source and dest tensors are ordered in different ways. We only do direct
1427 // dense to sparse conversion when the dense input is defined by a sparse
1428 // constant. Note that we can theoretically always directly convert from dense
1429 // inputs by rotating dense loops but it leads to bad cache locality and hurt
1430 // performance.
1431 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1432 if (isa<SparseElementsAttr>(constOp.getValue()))
1433 return false;
1434
1435 return true;
1436}
1437
1438LogicalResult CrdTranslateOp::verify() {
1439 uint64_t inRank = getEncoder().getLvlRank();
1440 uint64_t outRank = getEncoder().getDimRank();
1441
1442 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1443 std::swap(inRank, outRank);
1444
1445 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1446 return emitError("Coordinate rank mismatch with encoding");
1447
1448 return success();
1449}
1450
1451LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1452 SmallVectorImpl<OpFoldResult> &results) {
1453 if (getEncoder().isIdentity()) {
1454 results.assign(getInCrds().begin(), getInCrds().end());
1455 return success();
1456 }
1457 if (getEncoder().isPermutation()) {
1458 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1459 ? getEncoder().getDimToLvl()
1460 : getEncoder().getLvlToDim();
1461 for (AffineExpr exp : perm.getResults())
1462 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1463 return success();
1464 }
1465
1466 // Fuse dim2lvl/lvl2dim pairs.
1467 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1468 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1469 return v.getDefiningOp() == def;
1470 });
1471 if (!sameDef)
1472 return failure();
1473
1474 bool oppositeDir = def.getDirection() != getDirection();
1475 bool sameOracle =
1476 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1477 bool sameCount = def.getNumResults() == getInCrds().size();
1478 if (!oppositeDir || !sameOracle || !sameCount)
1479 return failure();
1480
1481 // The definition produces the coordinates in the same order as the input
1482 // coordinates.
1483 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1484 [](auto valuePair) {
1485 auto [lhs, rhs] = valuePair;
1486 return lhs == rhs;
1487 });
1488
1489 if (!sameOrder)
1490 return failure();
1491 // l1 = dim2lvl (lvl2dim l0)
1492 // ==> l0
1493 results.append(def.getInCrds().begin(), def.getInCrds().end());
1494 return success();
1495}
1496
1497void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1498 int64_t index) {
1499 Value val = arith::ConstantIndexOp::create(builder, state.location, index);
1500 return build(builder, state, source, val);
1501}
1502
1503LogicalResult LvlOp::verify() {
1504 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1505 auto stt = getSparseTensorType(getSource());
1506 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1507 return emitError(
1508 "Level index exceeds the rank of the input sparse tensor");
1509 }
1510 return success();
1511}
1512
std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  // Returns the level operand as a constant, or std::nullopt when dynamic.
  return getConstantIntValue(getIndex());
}
1516
Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    // NOTE(review): the return statement for the dynamic-index case is
    // missing from this excerpt — verify against upstream.

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  // NOTE(review): the final return statement (for the in-bounds constant
  // case) is also missing from this excerpt.
}
1526
1527OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1528 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1529 if (!lvlIndex)
1530 return {};
1531
1532 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1533 auto stt = getSparseTensorType(getSource());
1534 if (lvl >= stt.getLvlRank()) {
1535 // Follows the same convention used by tensor.dim operation. Out of bound
1536 // indices produce undefined behavior but are still valid IR. Don't choke on
1537 // them.
1538 return {};
1539 }
1540
1541 // Helper lambda to build an IndexAttr.
1542 auto getIndexAttr = [this](int64_t lvlSz) {
1543 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1544 };
1545
1546 SmallVector<Size> lvlShape = stt.getLvlShape();
1547 if (ShapedType::isStatic(lvlShape[lvl]))
1548 return getIndexAttr(lvlShape[lvl]);
1549
1550 return {};
1551}
1552
1553void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1554 SparseTensorEncodingAttr dstEnc, Value source) {
1555 auto srcStt = getSparseTensorType(source);
1556 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1557 SmallVector<int64_t> dstDimShape =
1558 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1559 auto dstTp =
1560 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1561 return build(odsBuilder, odsState, dstTp, source);
1562}
1563
1564LogicalResult ReinterpretMapOp::verify() {
1565 auto srcStt = getSparseTensorType(getSource());
1566 auto dstStt = getSparseTensorType(getDest());
1567 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1568 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1569
1570 if (srcLvlTps.size() != dstLvlTps.size())
1571 return emitError("Level rank mismatch between source/dest tensors");
1572
1573 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1574 if (srcLvlTp != dstLvlTp)
1575 return emitError("Level type mismatch between source/dest tensors");
1576
1577 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1578 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1579 return emitError("Crd/Pos width mismatch between source/dest tensors");
1580 }
1581
1582 if (srcStt.getElementType() != dstStt.getElementType())
1583 return emitError("Element type mismatch between source/dest tensors");
1584
1585 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1586 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1587 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1588 if (srcLvlSz != dstLvlSz) {
1589 // Should we allow one side to be dynamic size, e.g., <?x?> should be
1590 // compatible to <3x4>? For now, we require all the level sizes to be
1591 // *exactly* matched for simplicity.
1592 return emitError("Level size mismatch between source/dest tensors");
1593 }
1594 }
1595
1596 return success();
1597}
1598
1599OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1600 if (getSource().getType() == getDest().getType())
1601 return getSource();
1602
1603 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1604 // A -> B, B -> A ==> A
1605 if (def.getSource().getType() == getDest().getType())
1606 return def.getSource();
1607 }
1608 return {};
1609}
1610
// Shared return-type inference for the to_positions/to_coordinates/
// to_coordinates_buffer/to_values ops: computes the memref type of the
// requested underlying storage buffer.
template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           PropertyRef prop, RegionRange region,
    // NOTE(review): the final parameter line (presumably
    // `SmallVectorImpl<mlir::Type> &ret) {`) is missing from this excerpt.
  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  Type elemTp = nullptr;
  bool withStride = false;
  // Select the buffer element type (and possible stride) by op kind.
  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
    elemTp = stt.getPosType();
  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
    elemTp = stt.getCrdType();
    // Coordinates at/after the AoS COO start get a strided layout.
    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
    elemTp = stt.getElementType();
  }

  assert(elemTp && "unhandled operation.");
  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
  bufShape.push_back(ShapedType::kDynamic);

  auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
                                 stt.getContext(), ShapedType::kDynamic,
                                 {ShapedType::kDynamic})
                           : StridedLayoutAttr();
  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
  return success();
}
1641
1642LogicalResult ToPositionsOp::verify() {
1643 auto stt = getSparseTensorType(getTensor());
1644 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1645 return emitError("requested level is out of bounds");
1646 if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1647 return emitError("unexpected type for positions");
1648 return success();
1649}
1650
LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                PropertyRef prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared buffer-type inference, specialized by op kind.
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}
1658
1659LogicalResult ToCoordinatesOp::verify() {
1660 auto stt = getSparseTensorType(getTensor());
1661 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1662 return emitError("requested level is out of bounds");
1663 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1664 return emitError("unexpected type for coordinates");
1665 return success();
1666}
1667
LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  PropertyRef prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared buffer-type inference, specialized by op kind.
  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}
1675
1676LogicalResult ToCoordinatesBufferOp::verify() {
1677 auto stt = getSparseTensorType(getTensor());
1678 if (stt.getAoSCOOStart() >= stt.getLvlRank())
1679 return emitError("expected sparse tensor with a COO region");
1680 return success();
1681}
1682
LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, PropertyRef prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared buffer-type inference, specialized by op kind.
  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
                                                      ret);
}
1690
1691LogicalResult ToValuesOp::verify() {
1692 auto stt = getSparseTensorType(getTensor());
1693 auto mtp = getMemRefType(getResult());
1694 if (stt.getElementType() != mtp.getElementType())
1695 return emitError("unexpected mismatch in element types");
1696 return success();
1697}
1698
LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           PropertyRef prop, RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared buffer-type inference, specialized by op kind.
  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}
1706
1707LogicalResult ToSliceOffsetOp::verify() {
1708 auto rank = getSlice().getType().getRank();
1709 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1710 return emitError("requested dimension out of bound");
1711 return success();
1712}
1713
1714LogicalResult ToSliceStrideOp::verify() {
1715 auto rank = getSlice().getType().getRank();
1716 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1717 return emitError("requested dimension out of bound");
1718 return success();
1719}
1720
LogicalResult GetStorageSpecifierOp::verify() {
  // Delegate to the shared getter/setter checks (kind/level consistency).
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}
1725
// Returns the SetStorageSpecifierOp that defines `op`'s specifier operand,
// or null when the specifier is not produced by such an op.
template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}
1730
1731OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1732 const StorageSpecifierKind kind = getSpecifierKind();
1733 const auto lvl = getLevel();
1734 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1735 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1736 return op.getValue();
1737 return {};
1738}
1739
LogicalResult SetStorageSpecifierOp::verify() {
  // Delegate to the shared getter/setter checks (kind/level consistency).
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}
1744
1745template <class T>
1746static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1747 const char *regionName,
1748 TypeRange inputTypes, Type outputType) {
1749 unsigned numArgs = region.getNumArguments();
1750 unsigned expectedNum = inputTypes.size();
1751 if (numArgs != expectedNum)
1752 return op->emitError() << regionName << " region must have exactly "
1753 << expectedNum << " arguments";
1754
1755 for (unsigned i = 0; i < numArgs; i++) {
1756 Type typ = region.getArgument(i).getType();
1757 if (typ != inputTypes[i])
1758 return op->emitError() << regionName << " region argument " << (i + 1)
1759 << " type mismatch";
1760 }
1761 Block &block = region.front();
1762 if (!block.mightHaveTerminator())
1763 return op->emitError() << regionName
1764 << " region must end with a terminator";
1765
1766 Operation *term = block.getTerminator();
1767 YieldOp yield = dyn_cast<YieldOp>(term);
1768 if (!yield)
1769 return op->emitError() << regionName
1770 << " region must end with sparse_tensor.yield";
1771 if (!yield.hasSingleResult() ||
1772 yield.getSingleResult().getType() != outputType)
1773 return op->emitError() << regionName << " region yield type mismatch";
1774
1775 return success();
1776}
1777
1778LogicalResult BinaryOp::verify() {
1779 NamedAttrList attrs = (*this)->getAttrs();
1780 Type leftType = getX().getType();
1781 Type rightType = getY().getType();
1782 Type outputType = getOutput().getType();
1783 Region &overlap = getOverlapRegion();
1784 Region &left = getLeftRegion();
1785 Region &right = getRightRegion();
1786
1787 // Check correct number of block arguments and return type for each
1788 // non-empty region.
1789 if (!overlap.empty()) {
1790 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1791 TypeRange{leftType, rightType}, outputType)))
1792 return failure();
1793 }
1794 if (!left.empty()) {
1795 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1796 outputType)))
1797 return failure();
1798 } else if (getLeftIdentity()) {
1799 if (leftType != outputType)
1800 return emitError("left=identity requires first argument to have the same "
1801 "type as the output");
1802 }
1803 if (!right.empty()) {
1804 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1805 outputType)))
1806 return failure();
1807 } else if (getRightIdentity()) {
1808 if (rightType != outputType)
1809 return emitError("right=identity requires second argument to have the "
1810 "same type as the output");
1811 }
1812 return success();
1813}
1814
LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    // The absent region takes no arguments (there is no stored value).
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      // Yielding a block argument of the enclosing block would make the
      // result depend on the surrounding (linalg) operands, i.e., it would
      // not be invariant.
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      // Anything other than a constant that is defined inside the absent
      // region or in the enclosing block counts as locally computed.
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}
1848
1849bool ConcatenateOp::needsExtraSort() {
1850 SparseTensorType dstStt = getSparseTensorType(*this);
1851 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1852 return false;
1853
1854 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1855 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1856 });
1857 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1858 // in all input/output buffers, and all input/output buffers have the same
1859 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1860 // CSC matrices along column).
1861 bool directLowerable =
1862 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1863 return !directLowerable;
1864}
1865
1866LogicalResult ConcatenateOp::verify() {
1867 const auto dstTp = getSparseTensorType(*this);
1868 const Dimension concatDim = getDimension();
1869 const Dimension dimRank = dstTp.getDimRank();
1870
1871 if (getInputs().size() <= 1)
1872 return emitError("Need at least two tensors to concatenate.");
1873
1874 if (concatDim >= dimRank)
1875 return emitError(llvm::formatv(
1876 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1877 concatDim, dimRank));
1878
1879 for (const auto &it : llvm::enumerate(getInputs())) {
1880 const auto i = it.index();
1881 const auto srcTp = getSparseTensorType(it.value());
1882 if (srcTp.hasDynamicDimShape())
1883 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1884 const Dimension srcDimRank = srcTp.getDimRank();
1885 if (srcDimRank != dimRank)
1886 return emitError(
1887 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1888 "from the output tensor (rank={2}).",
1889 i, srcDimRank, dimRank));
1890 }
1891
1892 for (Dimension d = 0; d < dimRank; d++) {
1893 const Size dstSh = dstTp.getDimShape()[d];
1894 if (d == concatDim) {
1895 if (ShapedType::isStatic(dstSh)) {
1896 // If we reach here, then all inputs have static shapes. So we
1897 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1898 // to avoid redundant assertions in the loop.
1899 Size sumSz = 0;
1900 for (const auto src : getInputs())
1901 sumSz += getSparseTensorType(src).getDimShape()[d];
1902 // If all dimension are statically known, the sum of all the input
1903 // dimensions should be equal to the output dimension.
1904 if (sumSz != dstSh)
1905 return emitError(
1906 "The concatenation dimension of the output tensor should be the "
1907 "sum of all the concatenation dimensions of the input tensors.");
1908 }
1909 } else {
1910 Size prev = dstSh;
1911 for (const auto src : getInputs()) {
1912 const auto sh = getSparseTensorType(src).getDimShape()[d];
1913 if (ShapedType::isStatic(prev) && sh != prev)
1914 return emitError("All dimensions (expect for the concatenating one) "
1915 "should be equal.");
1916 prev = sh;
1917 }
1918 }
1919 }
1920
1921 return success();
1922}
1923
void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  // Convenience builder: delegate to the full builder with no repetition
  // count (the optional `n` operand is left absent).
  build(builder, result, curSize, inBuffer, value, Value());
}
1928
1929LogicalResult PushBackOp::verify() {
1930 if (Value n = getN()) {
1931 std::optional<int64_t> nValue = getConstantIntValue(n);
1932 if (nValue && nValue.value() < 1)
1933 return emitOpError("n must be not less than 1");
1934 }
1935 return success();
1936}
1937
1938LogicalResult CompressOp::verify() {
1939 const auto stt = getSparseTensorType(getTensor());
1940 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1941 return emitOpError("incorrect number of coordinates");
1942 return success();
1943}
1944
/// Builds a ForeachOp and, when `bodyBuilder` is provided, populates its
/// body block. The block arguments are laid out as:
/// [dimRank coordinates, one stored value, reduction (init) variables].
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  // Hand the builder callback the three argument groups separately:
  // coordinates, stored value, and reduction variables.
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
1975
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  // An explicit traversal order must map over all levels of the tensor.
  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  // Block arguments are: dimRank coordinates, one stored value, then the
  // reduction (init) arguments.
  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  // All coordinate arguments must be of index type.
  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  // The value argument must match the tensor's element type.
  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}
2014
2015OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2016 if (getSparseTensorEncoding(getInputCoo().getType()) ==
2017 getSparseTensorEncoding(getResultCoo().getType()))
2018 return getInputCoo();
2019
2020 return {};
2021}
2022
2023LogicalResult ReorderCOOOp::verify() {
2024 SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2025 SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2026
2027 if (!srcStt.isCOOType() || !dstStt.isCOOType())
2028 return emitError("Expected COO sparse tensors only");
2029
2030 if (!srcStt.hasSameDimToLvl(dstStt))
2031 return emitError("Unmatched dim2lvl map between input and result COO");
2032
2033 if (srcStt.getPosType() != dstStt.getPosType() ||
2034 srcStt.getCrdType() != dstStt.getCrdType() ||
2035 srcStt.getElementType() != dstStt.getElementType())
2036 return emitError("Unmatched storage format between input and result COO");
2037
2038 return success();
2039}
2040
2041LogicalResult ReduceOp::verify() {
2042 Type inputType = getX().getType();
2043 Region &formula = getRegion();
2044 return verifyNumBlockArgs(this, formula, "reduce",
2045 TypeRange{inputType, inputType}, inputType);
2046}
2047
2048LogicalResult SelectOp::verify() {
2049 Builder b(getContext());
2050 Type inputType = getX().getType();
2051 Type boolType = b.getI1Type();
2052 Region &formula = getRegion();
2053 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2054 boolType);
2055}
2056
2057LogicalResult SortOp::verify() {
2058 AffineMap xPerm = getPermMap();
2059 uint64_t nx = xPerm.getNumDims();
2060 if (nx < 1)
2061 return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2062
2063 if (!xPerm.isPermutation())
2064 return emitError(
2065 llvm::formatv("Expected a permutation map, got {0}", xPerm));
2066
2067 // We can't check the size of the buffers when n or buffer dimensions aren't
2068 // compile-time constants.
2069 std::optional<int64_t> cn = getConstantIntValue(getN());
2070 if (!cn)
2071 return success();
2072
2073 // Verify dimensions.
2074 const auto checkDim = [&](Value v, Size minSize,
2075 const char *message) -> LogicalResult {
2076 const Size sh = getMemRefType(v).getShape()[0];
2077 if (ShapedType::isStatic(sh) && sh < minSize)
2078 return emitError(
2079 llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2080 return success();
2081 };
2082 uint64_t n = cn.value();
2083 uint64_t ny = 0;
2084 if (auto nyAttr = getNyAttr())
2085 ny = nyAttr.getInt();
2086 if (failed(checkDim(getXy(), n * (nx + ny),
2087 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2088 return failure();
2089 for (Value opnd : getYs())
2090 if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2091 return failure();
2092
2093 return success();
2094}
2095
2096//===----------------------------------------------------------------------===//
2097// Sparse Tensor Iteration Operations.
2098//===----------------------------------------------------------------------===//
2099
// Returns the iteration-space type matching this iterator (same encoding
// and level range).
IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}
2104
// Returns the iterator type matching this iteration space (same encoding
// and level range).
IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}
2108
2109/// Parses a level range in the form "$lo `to` $hi"
2110/// or simply "$lo" if $hi - $lo = 1
2111static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2112 Level &lvlHi) {
2113 if (parser.parseInteger(lvlLo))
2114 return failure();
2115
2116 if (succeeded(parser.parseOptionalKeyword("to"))) {
2117 if (parser.parseInteger(lvlHi))
2118 return failure();
2119 } else {
2120 lvlHi = lvlLo + 1;
2121 }
2122
2123 if (lvlHi <= lvlLo)
2124 return parser.emitError(parser.getNameLoc(),
2125 "expect larger level upper bound than lower bound");
2126
2127 return success();
2128}
2129
2130/// Parses a level range in the form "$lo `to` $hi"
2131/// or simply "$lo" if $hi - $lo = 1
2132static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2133 IntegerAttr &lvlHiAttr) {
2134 Level lvlLo, lvlHi;
2135 if (parseLevelRange(parser, lvlLo, lvlHi))
2136 return failure();
2137
2138 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2139 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2140 return success();
2141}
2142
2143/// Prints a level range in the form "$lo `to` $hi"
2144/// or simply "$lo" if $hi - $lo = 1
2145static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2146
2147 if (lo + 1 == hi)
2148 p << lo;
2149 else
2150 p << lo << " to " << hi;
2151}
2152
/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                            IntegerAttr lvlHi) {
  // Unwrap the attributes and delegate to the Level-based overload.
  unsigned lo = lvlLo.getValue().getZExtValue();
  unsigned hi = lvlHi.getValue().getZExtValue();
  printLevelRange(p, lo, hi);
}
2161
2162/// Parses a list of `optional` defined list in the form of
2163/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2164/// corresponding value is not defined (e.g., to represent an undefined
2165/// coordinate in the sparse iteration space).
2166static ParseResult parseOptionalDefinedList(
2167 OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2169 unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2171 unsigned cnt = 0;
2172 ParseResult crdList =
2173 parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2174 if (parser.parseOptionalKeyword("_")) {
2175 if (parser.parseArgument(definedArgs.emplace_back()))
2176 return failure();
2177 definedSet.set(cnt);
2178 }
2179 cnt += 1;
2180 return success();
2181 });
2182
2183 if (cnt > maxCnt)
2184 return parser.emitError(parser.getNameLoc(),
2185 "parsed more value than expected.");
2186
2187 if (failed(crdList)) {
2188 return parser.emitError(
2189 parser.getNameLoc(),
2190 "expecting SSA value or \"_\" for level coordinates");
2191 }
2192 assert(definedArgs.size() == definedSet.count());
2193 return success();
2194}
2195
2196static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2197 Block::BlockArgListType blocksArgs,
2198 I64BitSet definedSet) {
2199 if (definedSet.empty())
2200 return;
2201
2202 for (unsigned i = 0; i < size; i++) {
2203 if (definedSet[i]) {
2204 p << blocksArgs.front();
2205 blocksArgs = blocksArgs.drop_front();
2206 } else {
2207 p << "_";
2208 }
2209 if (i != size - 1)
2210 p << ", ";
2211 }
2212 assert(blocksArgs.empty());
2213}
2214
2215static ParseResult
2218 // Parse "at(%crd0, _, ...)"
2219 I64BitSet crdUsedLvlSet;
2220 if (succeeded(parser.parseOptionalKeyword("at")) &&
2221 failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2222 return failure();
2223
2224 // Always use IndexType for the coordinate.
2225 for (auto &coord : coords)
2226 coord.type = parser.getBuilder().getIndexType();
2227
2228 // Set the CrdUsedLvl bitset.
2229 state.addAttribute("crdUsedLvls",
2230 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2231 return success();
2232}
2233
2234static ParseResult
2240
2241 // Parse "%iters, ... in %spaces, ..."
2242 if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2243 parser.parseOperandList(spaces))
2244 return failure();
2245
2246 if (iterators.size() != spaces.size())
2247 return parser.emitError(
2248 parser.getNameLoc(),
2249 "mismatch in number of sparse iterators and sparse spaces");
2250
2252 if (failed(parseUsedCoordList(parser, state, coords)))
2253 return failure();
2254 size_t numCrds = coords.size();
2255
2256 // Parse "iter_args(%arg = %init, ...)"
2257 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2258 if (hasIterArgs)
2259 if (parser.parseAssignmentList(blockArgs, initArgs))
2260 return failure();
2261
2262 blockArgs.append(coords);
2263
2264 SmallVector<Type> iterSpaceTps;
2265 // parse ": sparse_tensor.iter_space -> ret"
2266 if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2267 return failure();
2268 if (iterSpaceTps.size() != spaces.size())
2269 return parser.emitError(parser.getNameLoc(),
2270 "mismatch in number of iteration space operands "
2271 "and iteration space types");
2272
2273 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2274 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2275 if (!spaceTp)
2276 return parser.emitError(parser.getNameLoc(),
2277 "expected sparse_tensor.iter_space type for "
2278 "iteration space operands");
2279 it.type = spaceTp.getIteratorType();
2280 }
2281
2282 if (hasIterArgs)
2283 if (parser.parseArrowTypeList(state.types))
2284 return failure();
2285
2286 // Resolves input operands.
2287 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2288 state.operands))
2289 return failure();
2290
2291 if (hasIterArgs) {
2292 // Strip off leading args that used for coordinates.
2293 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2294 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2295 return parser.emitError(
2296 parser.getNameLoc(),
2297 "mismatch in number of iteration arguments and return values");
2298 }
2299
2300 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2301 it.type = tp;
2302 if (parser.resolveOperand(init, tp, state.operands))
2303 return failure();
2304 }
2305 }
2306 return success();
2307}
2308
2309static ParseResult
2311 SmallVectorImpl<Value> &spacesVals,
2313
2314 // Parse "(%spaces, ...)"
2317 return failure();
2318
2320 if (failed(parseUsedCoordList(parser, state, coords)))
2321 return failure();
2322 size_t numCrds = coords.size();
2323
2324 // Parse "iter_args(%arg = %init, ...)"
2326 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2327 if (hasIterArgs)
2328 if (parser.parseAssignmentList(blockArgs, initArgs))
2329 return failure();
2330 blockArgs.append(coords);
2331
2332 SmallVector<Type> iterSpaceTps;
2333 // parse ": (sparse_tensor.iter_space, ...) -> ret"
2334 if (parser.parseColon() || parser.parseLParen() ||
2335 parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2336 return failure();
2337
2338 if (iterSpaceTps.size() != spaces.size())
2339 return parser.emitError(parser.getNameLoc(),
2340 "mismatch in number of iteration space operands "
2341 "and iteration space types");
2342
2343 if (hasIterArgs)
2344 if (parser.parseArrowTypeList(state.types))
2345 return failure();
2346
2347 // Resolves input sparse iteration spaces.
2348 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2349 spacesVals))
2350 return failure();
2351 state.operands.append(spacesVals);
2352
2353 if (hasIterArgs) {
2354 // Strip off trailing args that used for coordinates.
2355 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2356 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2357 return parser.emitError(
2358 parser.getNameLoc(),
2359 "mismatch in number of iteration arguments and return values");
2360 }
2361
2362 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2363 it.type = tp;
2364 if (parser.resolveOperand(init, tp, state.operands))
2365 return failure();
2366 }
2367 }
2368 return success();
2369}
2370
2371LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2372 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2373 DictionaryAttr attr, PropertyRef prop, RegionRange region,
2374 SmallVectorImpl<mlir::Type> &ret) {
2375
2376 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2377 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2378 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2379 adaptor.getHiLvl()));
2380 return success();
2381}
2382
LogicalResult ExtractIterSpaceOp::verify() {
  // The extracted level range must be non-empty.
  if (getLoLvl() >= getHiLvl())
    return emitOpError("expected smaller level low than level high");

  // A parent iterator is required exactly when extracting a non-root range
  // (loLvl != 0); the root space (loLvl == 0) takes no parent.
  TypedValue<IteratorType> pIter = getParentIter();
  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
    return emitOpError(
        "parent iterator should be specified iff level lower bound equals 0");
  }

  if (pIter) {
    IterSpaceType spaceTp = getExtractedSpace().getType();
    // Parent iterator and extracted space must agree on the encoding.
    if (pIter.getType().getEncoding() != spaceTp.getEncoding())
      return emitOpError(
          "mismatch in parent iterator encoding and iteration space encoding.");

    // The extracted range must start right where the parent iterator ends.
    if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
      return emitOpError("parent iterator should be used to extract an "
                         "iteration space from a consecutive level.");
  }

  return success();
}
2406
2407LogicalResult ExtractValOp::verify() {
2408 auto stt = getSparseTensorType(getTensor());
2409 auto itTp = getIterator().getType();
2410
2411 if (stt.getEncoding() != itTp.getEncoding())
2412 return emitOpError("mismatch in tensor encoding and iterator encoding.");
2413
2414 if (stt.getLvlRank() != itTp.getHiLvl())
2415 return emitOpError("must use last-level iterator to extract values. ");
2416
2417 return success();
2418}
2419
2420struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2422
2423 LogicalResult matchAndRewrite(IterateOp iterateOp,
2424 PatternRewriter &rewriter) const override {
2425 I64BitSet newUsedLvls(0);
2426 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2427 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2428 if (auto crd = iterateOp.getLvlCrd(i)) {
2429 if (crd->getUsers().empty())
2430 toRemove.set(crd->getArgNumber());
2431 else
2432 newUsedLvls.set(i);
2433 }
2434 }
2435
2436 // All coordinates are used.
2437 if (toRemove.none())
2438 return failure();
2439
2440 rewriter.startOpModification(iterateOp);
2441 iterateOp.setCrdUsedLvls(newUsedLvls);
2442 iterateOp.getBody()->eraseArguments(toRemove);
2443 rewriter.finalizeOpModification(iterateOp);
2444 return success();
2445 }
2446};
2447
// Registers IterateOp canonicalizations: currently only the removal of
// unused level-coordinate block arguments.
void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                            mlir::MLIRContext *context) {
  results.add<RemoveUnusedLvlCrds>(context);
}
2452
2453void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2454 Value iterSpace, ValueRange initArgs) {
2455 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2456 // All ones.
2457 I64BitSet set((1 << rank) - 1);
2458 return build(builder, odsState, iterSpace, initArgs, set);
2459}
2460
/// Builds an IterateOp and its body block. The block arguments are laid out
/// as: [loop-carried (init) args, used coordinates, sparse iterator].
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(iterSpace);
  odsState.addOperands(initArgs);
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  odsState.addTypes(initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(bodyRegion);

  // Starts with a list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Follows by a list of used coordinates (one index per set bit).
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Ends with sparse iterator
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);
}
2487
2488ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2489 OpAsmParser::Argument iterator;
2490 OpAsmParser::UnresolvedOperand iterSpace;
2491
2492 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2493 if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2494 return failure();
2495 if (iters.size() != 1)
2496 return parser.emitError(parser.getNameLoc(),
2497 "expected only one iterator/iteration space");
2498
2499 iterArgs.append(iters);
2500 Region *body = result.addRegion();
2501 if (parser.parseRegion(*body, iterArgs))
2502 return failure();
2503
2504 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2505
2506 // Parse the optional attribute list.
2507 if (parser.parseOptionalAttrDict(result.attributes))
2508 return failure();
2509
2510 return success();
2511}
2512
2513/// Prints the initialization list in the form of
2514/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2515/// where 'inner' values are assumed to be region arguments and 'outer' values
2516/// are regular SSA values.
2518 Block::BlockArgListType blocksArgs,
2519 ValueRange initializers,
2520 StringRef prefix = "") {
2521 assert(blocksArgs.size() == initializers.size() &&
2522 "expected same length of arguments and initializers");
2523 if (initializers.empty())
2524 return;
2525
2526 p << prefix << '(';
2527 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2528 p << std::get<0>(it) << " = " << std::get<1>(it);
2529 });
2530 p << ")";
2531}
2532
/// Shared verifier for sparse loop-like ops (IterateOp/CoIterateOp):
/// checks the loop-carried value/result correspondence and that every
/// requested coordinate addresses a level inside the iteration space.
template <typename SparseLoopOp>
static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
  if (op.getInitArgs().size() != op.getNumResults()) {
    return op.emitOpError(
        "mismatch in number of loop-carried values and defined values");
  }
  // The highest used-coordinate bit must stay within the space dimension.
  if (op.getCrdUsedLvls().max() > op.getSpaceDim())
    return op.emitOpError("required out-of-bound coordinates");

  return success();
}
2544
// Structural checks shared between the two sparse loop ops.
LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
// Structural checks shared between the two sparse loop ops.
LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2547
void IterateOp::print(OpAsmPrinter &p) {
  // "%iter in %space"
  p << " " << getIterator() << " in " << getIterSpace();
  // Optional "at(%crd0, _, ...)" clause for the used coordinates.
  if (!getCrdUsedLvls().empty()) {
    p << " at(";
    printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
    p << ")";
  }
  // Optional "iter_args(%arg = %init, ...)" clause.
  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");

  p << " : " << getIterSpace().getType() << " ";
  if (!getInitArgs().empty())
    p.printArrowTypeList(getInitArgs().getTypes());

  p << " ";
  // Terminators are only printed when loop-carried values make them
  // meaningful.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
}
2565
LogicalResult IterateOp::verifyRegions() {
  // The iterator block argument must match the iteration space operand.
  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
    return emitOpError("mismatch in iterator and iteration space type");
  if (getNumRegionIterArgs() != getNumResults())
    return emitOpError(
        "mismatch in number of basic block args and defined values");

  auto initArgs = getInitArgs();
  auto iterArgs = getRegionIterArgs();
  auto yieldVals = getYieldedValues();
  auto opResults = getResults();
  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
                        opResults.size()})) {
    return emitOpError() << "number mismatch between iter args and results.";
  }

  // Init operand, region iter-arg, yielded value, and result must all agree
  // in type, position by position.
  for (auto [i, init, iter, yield, ret] :
       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
    if (init.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
    if (iter.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";
    if (yield.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th yield value and defined value";
  }

  return success();
}
2597
/// OpInterfaces' methods implemented by IterateOp.
// LoopLikeOpInterface: the single body region is the only loop region.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2600
// LoopLikeOpInterface: the loop-carried init operands.
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
  return getInitArgsMutable();
}
2604
// The loop-carried block arguments are laid out first in the region's
// argument list (before coordinates and the iterator).
Block::BlockArgListType IterateOp::getRegionIterArgs() {
  return getRegion().getArguments().take_front(getNumRegionIterArgs());
}
2608
// The yielded values are the operands of the region's sparse_tensor.yield
// terminator.
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
  return cast<sparse_tensor::YieldOp>(
             getRegion().getBlocks().front().getTerminator())
      .getResultsMutable();
}
2614
// All op results are loop results.
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2616
// RegionBranchOpInterface: the init args are forwarded on entry regardless
// of the successor.
OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  return getInitArgs();
}
2620
void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself.
  regions.push_back(RegionSuccessor(&getRegion()));
  // It is possible for loop not to enter the body.
  regions.push_back(RegionSuccessor::parent());
}
2629
2630ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) {
2631 return successor.isParent() ? ValueRange(getResults())
2632 : ValueRange(getRegionIterArgs());
2633}
2634
2635void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2636 ValueRange iterSpaces, ValueRange initArgs,
2637 unsigned numCases) {
2638 unsigned rank =
2639 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2640 // All ones.
2641 I64BitSet set((1 << rank) - 1);
2642 // Generates all-zero case bits (they only serve as placeholders), which are
2643 // supposed to be overriden later. We need to preallocate all the regions as
2644 // mlir::Region cannot be dynamically added later after the operation is
2645 // created.
2646 SmallVector<int64_t> caseBits(numCases, 0);
2647 ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2648 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2649 initArgs, set, cases,
2650 /*caseRegionsCount=*/numCases);
2651}
2652
2653ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2654
2655 SmallVector<Value> spaces;
2656 // The block argument list of each regions, it is arranged in the order of
2657 // ([used coordinate list], [loop iterations args], [sparse iterator list]).
2658 SmallVector<OpAsmParser::Argument> blockArgs;
2659 if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2660 return failure();
2661
2662 result.addAttribute("operandSegmentSizes",
2664 {static_cast<int32_t>(spaces.size()),
2665 static_cast<int32_t>(result.types.size())}));
2666
2667 SmallVector<Attribute> cases;
2668 while (succeeded(parser.parseOptionalKeyword("case"))) {
2669 // Parse one region per case.
2670 I64BitSet definedItSet;
2671 SmallVector<OpAsmParser::Argument> definedIts;
2672 if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2673 spaces.size(), OpAsmParser::Delimiter::None))
2674 return failure();
2675
2676 cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2677
2678 for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2679 // Resolve the iterator type based on the iteration space type.
2680 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2681 definedIts[i].type = spaceTp.getIteratorType();
2682 }
2683 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2684 Region *body = result.addRegion();
2685 if (parser.parseRegion(*body, definedIts))
2686 return failure();
2687
2688 CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2689 }
2690
2691 result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2692
2693 // Parse the optional attribute list.
2694 if (parser.parseOptionalAttrDict(result.attributes))
2695 return failure();
2696
2697 return success();
2698}
2699
2700void CoIterateOp::print(OpAsmPrinter &p) {
2701 p << " (";
2702 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2703 p << ")";
2704
2705 if (!getCrdUsedLvls().empty()) {
2706 p << " at(";
2707 printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2708 p << ")";
2709 }
2710
2711 printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2712
2713 p << " : (" << getIterSpaces().getTypes() << ")";
2714 if (!getInitArgs().empty())
2715 p.printArrowTypeList(getInitArgs().getTypes());
2716
2717 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2718 p.printNewline();
2719 p << "case ";
2720 printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2721 getRegionDefinedSpace(idx));
2722 p << " ";
2723 p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2724 /*printBlockTerminators=*/!getInitArgs().empty());
2725 }
2726}
2727
2728ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2729 return cast<sparse_tensor::YieldOp>(
2730 getRegion(regionIdx).getBlocks().front().getTerminator())
2731 .getResults();
2732}
2733
2734LogicalResult CoIterateOp::verifyRegions() {
2735 for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2736 if (getNumRegionIterArgs() != getNumResults())
2737 return emitOpError(
2738 "mismatch in number of basic block args and defined values");
2739
2740 auto initArgs = getInitArgs();
2741 auto iterArgs = getRegionIterArgs(r);
2742 auto yieldVals = getYieldedValues(r);
2743 auto opResults = getResults();
2744 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2745 opResults.size()})) {
2746 return emitOpError()
2747 << "number mismatch between iter args and results on " << r
2748 << "th region";
2749 }
2750
2751 for (auto [i, init, iter, yield, ret] :
2752 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2753 if (init.getType() != ret.getType())
2754 return emitOpError()
2755 << "types mismatch between " << i
2756 << "th iter operand and defined value on " << r << "th region";
2757 if (iter.getType() != ret.getType())
2758 return emitOpError() << "types mismatch between " << i
2759 << "th iter region arg and defined value on " << r
2760 << "th region";
2761 if (yield.getType() != ret.getType())
2762 return emitOpError()
2763 << "types mismatch between " << i
2764 << "th yield value and defined value on " << r << "th region";
2765 }
2766 }
2767
2768 auto cases = getRegionDefinedSpaces();
2769 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2770 if (set.size() != getNumRegions())
2771 return emitOpError("contains duplicated cases.");
2772
2773 return success();
2774}
2775
2776SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2777 SmallVector<Region *> ret;
2778 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2779 for (Region &r : getCaseRegions())
2780 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2781 ret.push_back(&r);
2782
2783 return ret;
2784}
2785
2786//===----------------------------------------------------------------------===//
2787// Sparse Tensor Dialect Setups.
2788//===----------------------------------------------------------------------===//
2789
2790/// Materialize a single constant operation from a given attribute value with
2791/// the desired resultant type.
2792Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2793 Attribute value, Type type,
2794 Location loc) {
2795 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2796 return op;
2797 return nullptr;
2798}
2799
/// Registers all ODS-generated attributes, types, and operations of the
/// sparse tensor dialect, and declares promised interface implementations.
void SparseTensorDialect::initialize() {
  // Attribute, type, and op lists are generated by TableGen and pulled in
  // through the *.cpp.inc files below.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Declare (rather than attach) the BufferizableOpInterface models for
  // these ops; the actual implementations are registered externally,
  // avoiding a hard dependency on bufferization here.
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2818
2819#define GET_OP_CLASSES
2820#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2821
2822#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:59
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, PropertyRef prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDiml mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isUniqueLT(LevelType lt)
Definition Enums.h:428
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Definition Enums.h:431
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition Enums.h:402
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
Definition Enums.h:432
bool isOrderedLT(LevelType lt)
Definition Enums.h:425
std::string toMLIRString(LevelType lt)
Definition Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
Definition Enums.h:421
static llvm::hash_code hash_value(LevelType lt)
uint64_t getN(LevelType lt)
Definition Enums.h:442
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
Definition Enums.h:413
uint64_t getM(LevelType lt)
Definition Enums.h:443
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
Definition Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Definition Enums.h:424
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition Enums.h:299