MLIR 23.0.0git
SparseTensorDialect.cpp
Go to the documentation of this file.
1//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
17
22#include "mlir/IR/Builders.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32
33// Forward declarations, following custom print/parsing methods are referenced
34// by the generated code for SparseTensorTypes.td.
35static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
40
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
43
44using namespace mlir;
45using namespace mlir::sparse_tensor;
46
47// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
48// well.
49namespace mlir::sparse_tensor {
50static llvm::hash_code hash_value(LevelType lt) {
51 return llvm::hash_value(static_cast<uint64_t>(lt));
52}
53} // namespace mlir::sparse_tensor
54
55//===----------------------------------------------------------------------===//
56// Local Convenience Methods.
57//===----------------------------------------------------------------------===//
58
/// Returns true iff `bitWidth` is an admissible pos/crd bitwidth: 0 (which
/// selects the native `index` type) or one of the standard power-of-two
/// integer widths up to 64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
71
// Computes the memref shape shared by the position/coordinate/value buffers
// of a sparse tensor with encoding `enc`: one entry per leading batch level,
// followed by a single dynamic dimension for the flattened sparse storage.
// NOTE(review): the declaration line preceding this one (static linkage and
// return type) was elided from this listing.
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only encoding, we can not determine the static shape for leading
  // batch levels, we therefore return a dynamic shape memref instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimension.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // Another dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}
91
92//===----------------------------------------------------------------------===//
93// SparseTensorDialect StorageLayout.
94//===----------------------------------------------------------------------===//
95
// Sentinel values for fields that are not associated with any level.
static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
// Index of the first data field within the storage layout.
static constexpr FieldIndex kDataFieldStartingIdx = 0;

// Walks every storage field of the layout in order (per-level positions and
// coordinates, then the values array, then the storage specifier), invoking
// `callback` for each; a callback returning false stops the iteration.
// NOTE(review): the signature lines of StorageLayout::foreachField (and the
// declaration/initialization of `fieldIdx`) were elided from this listing.
                             LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    // Levels carrying a positions buffer get a PosMemRef field.
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    // Levels carrying a coordinates buffer get a CrdMemRef field.
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memrefs. Skips the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
  // NOTE(review): the trailing argument(s) of this call were elided from
  // this listing.
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
  // NOTE(review): the trailing argument(s) of this call were elided from
  // this listing.
    return;
}
146
// Variant of `foreachField` that additionally computes the concrete MLIR
// type of each field for the given sparse tensor type `stt`.
// NOTE(review): the signature lines of this function were elided from this
// listing.
                             LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
  // NOTE(review): the initializer of `memrefShape` was elided from this
  // listing -- presumably a call to getSparseFieldShape; confirm against the
  // original source.

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  // Re-dispatch each field with its corresponding type attached.
  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    // NOTE(review): the `case` labels of this switch were elided from this
    // listing; each case forwards the matching field type to `callback`.
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}
182
// Counts every storage field, including the trailing storage specifier.
// NOTE(review): the signature line of StorageLayout::getNumFields and the
// opening of its foreachField call were elided from this listing.
  unsigned numFields = 0;
                        LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

// Counts only the data fields (memrefs), excluding the storage specifier.
// NOTE(review): the signature line of StorageLayout::getNumDataFields and
// the opening of its foreachField call were elided from this listing.
  unsigned numFields = 0; // one value memref
                        LevelType) -> bool {
    // Only count fields in the data section of the layout.
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}
205
// Returns the storage-field index for the requested (kind, level), together
// with the access stride (greater than 1 only for AoS COO coordinates).
std::pair<FieldIndex, unsigned>
// NOTE(review): part of this signature, the declaration of `fieldIdx`, and
// the guard checking for the CrdMemRef kind were elided from this listing.
                                    std::optional<Level> lvl) const {
  unsigned stride = 1;
    assert(lvl.has_value());
    // All levels of an AoS COO segment share one coordinates memref; remap
    // the queried level to the segment start and report the stride.
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  // Scan the layout for the matching field.
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
234
235//===----------------------------------------------------------------------===//
236// SparseTensorDialect Attribute Methods.
237//===----------------------------------------------------------------------===//
238
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(static_cast<uint64_t>(v));
242}
243
/// Returns the slice offset when static, std::nullopt when dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

/// Returns the slice stride when static, std::nullopt when dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

/// Returns the slice size when static, std::nullopt when dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

/// Returns true iff offset, stride, and size are all dynamic (`?`).
bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

/// Renders one slice component for printing: `?` for dynamic, the decimal
/// value otherwise.
std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}
264
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
266 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
267 os << '(';
268 os << getStaticString(getOffset());
269 os << ", ";
270 os << getStaticString(getSize());
271 os << ", ";
272 os << getStaticString(getStride());
273 os << ')';
274}
275
276void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
277 print(printer.getStream());
278}
279
// Parses one slice component: either a non-negative integer or `?` (which
// denotes a dynamic value).
// NOTE(review): the first signature line (return type, function name, and
// the int64_t& `result` parameter) was elided from this listing.
                                              AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    // An integer token was present; reject negative values.
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, and '?' which represented dynamic slice
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}
297
298Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
300
301 if (failed(parser.parseLParen()) ||
302 failed(parseOptionalStaticSlice(offset, parser)) ||
303 failed(parser.parseComma()) ||
304 failed(parseOptionalStaticSlice(size, parser)) ||
305 failed(parser.parseComma()) ||
306 failed(parseOptionalStaticSlice(stride, parser)) ||
307 failed(parser.parseRParen()))
308 return {};
309
310 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
311 offset, size, stride);
312}
313
/// Verifies the slice components: offset must be non-negative or dynamic;
/// size and stride must be strictly positive or dynamic. Checks run in
/// offset/size/stride order, so the first violation determines the
/// diagnostic.
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
325
/// Returns a copy of this encoding with `dimToLvl` replaced. The lvlToDim
/// map is passed as the empty AffineMap here.
/// NOTE(review): presumably the builder re-infers lvlToDim from the new
/// dimToLvl -- confirm against SparseTensorEncodingAttr::get.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

/// Returns a copy with the dimToLvl map taken from `enc` (or cleared when
/// `enc` is null).
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

/// Returns a copy with the dimToLvl map cleared (a null map is treated as
/// the identity elsewhere in this file).
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

/// Returns a copy with the given position/coordinate bitwidths.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

/// Returns a copy with default bitwidths (0, i.e. the native index type).
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

/// Returns a copy with the given explicit value.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

/// Returns a copy with the explicit value cleared.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

/// Returns a copy with the given implicit value.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

/// Returns a copy with the implicit value cleared.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

/// Returns a copy with the given dimension slices.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

/// Returns a copy with all dimension slices removed.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
390
391uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
392 ArrayRef<LevelType> lvlTypes = getLvlTypes();
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
395}
396
397bool SparseTensorEncodingAttr::isAllDense() const {
398 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
399}
400
401bool SparseTensorEncodingAttr::isAllOrdered() const {
402 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
403}
404
405Type SparseTensorEncodingAttr::getCrdElemType() const {
406 if (!getImpl())
407 return nullptr;
408 if (getCrdWidth())
409 return IntegerType::get(getContext(), getCrdWidth());
410 return IndexType::get(getContext());
411}
412
413Type SparseTensorEncodingAttr::getPosElemType() const {
414 if (!getImpl())
415 return nullptr;
416 if (getPosWidth())
417 return IntegerType::get(getContext(), getPosWidth());
418 return IndexType::get(getContext());
419}
420
421MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
422 std::optional<ArrayRef<int64_t>> dimShape) const {
423 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
424 return MemRefType::get(shape, getCrdElemType());
425}
426
427MemRefType SparseTensorEncodingAttr::getPosMemRefType(
428 std::optional<ArrayRef<int64_t>> dimShape) const {
429 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
430 return MemRefType::get(shape, getPosElemType());
431}
432
/// A null encoding or an absent dimToLvl map is treated as the identity.
bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

/// A null encoding or an absent dimToLvl map counts as a permutation.
bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

/// Dimension rank: the input count of dimToLvl, or the level rank when no
/// map is present.
Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

/// Level rank: the length of the level-type array (the source of truth).
Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

/// Level type at level `l`; a null encoding reports Batch for every level.
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

/// Returns true when any dimension slice is attached to this encoding.
bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

/// Slice attribute for dimension `dim`; only valid on sliced encodings.
SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}
471
/// Static slice offset for dimension `dim` (std::nullopt when dynamic).
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

/// Static slice stride for dimension `dim` (std::nullopt when dynamic).
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

/// Static slice offset for level `lvl`, translated through toDim.
/// NOTE(review): the lvl->dim translation presumably assumes a permutation
/// dimToLvl -- confirm behavior for non-permutation encodings.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

/// Static slice stride for level `lvl`, translated through toDim.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}
491
492SmallVector<int64_t>
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir) const {
495 if (isIdentity())
496 return SmallVector<int64_t>(srcShape);
497
498 SmallVector<int64_t> ret;
499 unsigned rank =
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
501 ret.reserve(rank);
502
503 if (isPermutation()) {
504 for (unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
506 : toLvl(*this, r);
507 ret.push_back(srcShape[trans]);
508 }
509 return ret;
510 }
511
512 // Handle non-permutation maps.
513 AffineMap transMap =
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
515
516 SmallVector<AffineExpr> dimRep;
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (ShapedType::isStatic(sz)) {
520 // Push back the max coordinate for the given dimension/level size.
521 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
522 } else {
523 // A dynamic size, use a AffineDimExpr to symbolize the value.
524 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
525 }
526 };
527
528 for (AffineExpr exp : transMap.getResults()) {
529 // Do constant propagation on the affine map.
530 AffineExpr evalExp =
531 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
532 // use llvm namespace here to avoid ambiguity
533 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
534 ret.push_back(c.getValue() + 1);
535 } else {
536 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
537 mod && mod.getKind() == AffineExprKind::Mod) {
538 // We can still infer a static bound for expressions in form
539 // "d % constant" since d % constant \in [0, constant).
540 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
541 ret.push_back(bound.getValue());
542 continue;
543 }
544 }
545 ret.push_back(ShapedType::kDynamic);
546 }
547 }
548 assert(ret.size() == rank);
549 return ret;
550}
551
// Translates runtime coordinate values between dimension- and level-space by
// materializing a CrdTranslateOp; a null encoding is the identity.
// NOTE(review): the return-type line of this definition was elided from this
// listing -- presumably ValueRange, since `crds` is returned directly below.
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  // One index-typed result per target-space coordinate.
  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp =
      CrdTranslateOp::create(builder, loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}
566
567Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
568 // Open "<{" part.
569 if (failed(parser.parseLess()))
570 return {};
571 if (failed(parser.parseLBrace()))
572 return {};
573
574 // Process the data from the parsed dictionary value into struct-like data.
575 SmallVector<LevelType> lvlTypes;
576 SmallVector<SparseTensorDimSliceAttr> dimSlices;
577 AffineMap dimToLvl = {};
578 AffineMap lvlToDim = {};
579 unsigned posWidth = 0;
580 unsigned crdWidth = 0;
581 Attribute explicitVal;
582 Attribute implicitVal;
583 StringRef attrName;
584 SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
585 "explicitVal", "implicitVal"};
586 while (succeeded(parser.parseOptionalKeyword(&attrName))) {
587 // Detect admissible keyword.
588 auto *it = find(keys, attrName);
589 if (it == keys.end()) {
590 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
591 return {};
592 }
593 unsigned keyWordIndex = it - keys.begin();
594 // Consume the `=` after keys
595 if (failed(parser.parseEqual()))
596 return {};
597 // Dispatch on keyword.
598 switch (keyWordIndex) {
599 case 0: { // map
600 ir_detail::DimLvlMapParser cParser(parser);
601 auto res = cParser.parseDimLvlMap();
602 if (failed(res))
603 return {};
604 const auto &dlm = *res;
605
606 const Level lvlRank = dlm.getLvlRank();
607 for (Level lvl = 0; lvl < lvlRank; lvl++)
608 lvlTypes.push_back(dlm.getLvlType(lvl));
609
610 const Dimension dimRank = dlm.getDimRank();
611 for (Dimension dim = 0; dim < dimRank; dim++)
612 dimSlices.push_back(dlm.getDimSlice(dim));
613 // NOTE: the old syntax requires an all-or-nothing approach to
614 // `dimSlices`; therefore, if any slice actually exists then we need
615 // to convert null-DSA into default/nop DSA.
616 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
617 return static_cast<bool>(slice.getImpl());
618 };
619 if (llvm::any_of(dimSlices, isDefined)) {
620 const auto defaultSlice =
621 SparseTensorDimSliceAttr::get(parser.getContext());
622 for (Dimension dim = 0; dim < dimRank; dim++)
623 if (!isDefined(dimSlices[dim]))
624 dimSlices[dim] = defaultSlice;
625 } else {
626 dimSlices.clear();
627 }
628
629 dimToLvl = dlm.getDimToLvlMap(parser.getContext());
630 lvlToDim = dlm.getLvlToDimMap(parser.getContext());
631 break;
632 }
633 case 1: { // posWidth
634 Attribute attr;
635 if (failed(parser.parseAttribute(attr)))
636 return {};
637 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
638 if (!intAttr) {
639 parser.emitError(parser.getNameLoc(),
640 "expected an integral position bitwidth");
641 return {};
642 }
643 posWidth = intAttr.getInt();
644 break;
645 }
646 case 2: { // crdWidth
647 Attribute attr;
648 if (failed(parser.parseAttribute(attr)))
649 return {};
650 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
651 if (!intAttr) {
652 parser.emitError(parser.getNameLoc(),
653 "expected an integral index bitwidth");
654 return {};
655 }
656 crdWidth = intAttr.getInt();
657 break;
658 }
659 case 3: { // explicitVal
660 Attribute attr;
661 if (failed(parser.parseAttribute(attr)))
662 return {};
663 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
664 explicitVal = result;
665 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
666 explicitVal = result;
667 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
668 explicitVal = result;
669 } else {
670 parser.emitError(parser.getNameLoc(),
671 "expected a numeric value for explicitVal");
672 return {};
673 }
674 break;
675 }
676 case 4: { // implicitVal
677 Attribute attr;
678 if (failed(parser.parseAttribute(attr)))
679 return {};
680 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
681 implicitVal = result;
682 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
683 implicitVal = result;
684 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
685 implicitVal = result;
686 } else {
687 parser.emitError(parser.getNameLoc(),
688 "expected a numeric value for implicitVal");
689 return {};
690 }
691 break;
692 }
693 } // switch
694 // Only last item can omit the comma.
695 if (parser.parseOptionalComma().failed())
696 break;
697 }
698
699 // Close "}>" part.
700 if (failed(parser.parseRBrace()))
701 return {};
702 if (failed(parser.parseGreater()))
703 return {};
704
705 // Construct struct-like storage for attribute.
706 if (!lvlToDim || lvlToDim.isEmpty()) {
707 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
708 }
709 return parser.getChecked<SparseTensorEncodingAttr>(
710 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
711 explicitVal, implicitVal, dimSlices);
712}
713
714void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
715 auto map = static_cast<AffineMap>(getDimToLvl());
716 // Empty affine map indicates identity map
717 if (!map)
718 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
719 printer << "<{ map = ";
720 printSymbols(map, printer);
721 printer << '(';
722 printDimensions(map, printer, getDimSlices());
723 printer << ") -> (";
724 printLevels(map, printer, getLvlTypes());
725 printer << ')';
726 // Print remaining members only for non-default values.
727 if (getPosWidth())
728 printer << ", posWidth = " << getPosWidth();
729 if (getCrdWidth())
730 printer << ", crdWidth = " << getCrdWidth();
731 if (getExplicitVal()) {
732 printer << ", explicitVal = " << getExplicitVal();
733 }
734 if (getImplicitVal())
735 printer << ", implicitVal = " << getImplicitVal();
736 printer << " }>";
737}
738
739void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
740 AsmPrinter &printer) const {
741 if (map.getNumSymbols() == 0)
742 return;
743 printer << '[';
744 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
745 printer << 's' << i << ", ";
746 if (map.getNumSymbols() >= 1)
747 printer << 's' << map.getNumSymbols() - 1;
748 printer << ']';
749}
750
751void SparseTensorEncodingAttr::printDimensions(
752 AffineMap &map, AsmPrinter &printer,
753 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
754 if (!dimSlices.empty()) {
755 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
756 printer << 'd' << i << " : " << dimSlices[i] << ", ";
757 if (map.getNumDims() >= 1) {
758 printer << 'd' << map.getNumDims() - 1 << " : "
759 << dimSlices[map.getNumDims() - 1];
760 }
761 } else {
762 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
763 printer << 'd' << i << ", ";
764 if (map.getNumDims() >= 1)
765 printer << 'd' << map.getNumDims() - 1;
766 }
767}
768
769void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
770 ArrayRef<LevelType> lvlTypes) const {
771 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
772 map.getResult(i).print(printer.getStream());
773 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
774 }
775 if (map.getNumResults() >= 1) {
776 auto lastIndex = map.getNumResults() - 1;
777 map.getResult(lastIndex).print(printer.getStream());
778 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
779 }
780}
781
/// Verifies the structural invariants of a sparse tensor encoding:
/// admissible bitwidths, well-formed COO segments, leading batch levels,
/// SoA only on singleton levels, no dense level after a non-unique level,
/// n_out_of_m placement, and coherent dimToLvl/lvlToDim/dimSlices ranks.
LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  auto *it = llvm::find_if(lvlTypes, isSingletonLT);
  while (it != lvlTypes.end()) {
    if (it == lvlTypes.begin() ||
        // NOTE(review): the remainder of this condition (presumably checking
        // that the preceding level is compressed/loose_compressed) was
        // elided from this listing.
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    if (!std::all_of(it, curCOOEnd, isSingletonLT))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
          // NOTE(review): the right-hand side of this comparison (presumably
          // the SoA check on `i`) was elided from this listing.
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    // Advance to the next COO segment, if any.
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  // Batch levels, when present, must form a prefix of the level types.
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // SoA property can only be applied on singleton level.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // Dense levels cannot follow a non-unique level. The iteration model for
  // dense levels requires exactly one parent position to linearize into a
  // contiguous range, but a non-unique parent provides two cursor values
  // (segment start and end), which the dense level cannot handle.
  for (auto [i, lt] : llvm::drop_begin(llvm::enumerate(lvlTypes))) {
    if (isDenseLT(lt) && !isUniqueLT(lvlTypes[i - 1]))
      return emitError() << "dense level cannot follow a non-unique level";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      // All nonzero block sizes must share one coefficient, equal to the
      // `m` of the n_out_of_m level.
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        // NOTE(review): "coeffiencts" is a typo in this diagnostic string;
        // fixing it would change runtime output, so it is only flagged here.
        return emitError() << "expected coeffiencts of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}
907
/// Verifies this encoding against a concrete tensor type: first the
/// attribute's own structural invariants, then that the tensor's
/// dimension-rank and element type agree with the encoding, and finally
/// that the implicit value (if any) is zero.
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  if (auto expVal = getExplicitVal()) {
    // NOTE(review): dyn_cast yields a null attribute for a non-TypedAttr
    // value, which would then crash on getType(); presumably explicit
    // values are always typed attributes here -- confirm.
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    // Same TypedAttr caveat as for the explicit value above.
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}
955
956Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
957 SmallVector<COOSegment> coo = getCOOSegments();
958 assert(coo.size() == 1 || coo.empty());
959 if (!coo.empty() && coo.front().isAoS()) {
960 return coo.front().lvlRange.first;
961 }
962 return getLvlRank();
963}
964
965SmallVector<COOSegment>
966mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
967 SmallVector<COOSegment> ret;
968 if (getLvlRank() <= 1)
969 return ret;
970
971 ArrayRef<LevelType> lts = getLvlTypes();
972 Level l = 0;
973 while (l < getLvlRank()) {
974 auto lt = lts[l];
976 auto cur = lts.begin() + l;
977 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
978 return !lt.isa<LevelFormat::Singleton>();
979 });
980 unsigned cooLen = std::distance(cur, end);
981 if (cooLen > 1) {
982 // To support mixed SoA/AoS COO, we should break the segment when the
983 // storage scheme changes, for now we faithfully assume that all
984 // consecutive singleton levels have the same storage format as verified
985 // STEA.
986 ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
987 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
988 }
989 l += cooLen;
990 } else {
991 l++;
992 }
993 }
994 return ret;
995}
996
997//===----------------------------------------------------------------------===//
998// SparseTensorType Methods.
999//===----------------------------------------------------------------------===//
1000
1002 bool isUnique) const {
1003 if (!hasEncoding())
1004 return false;
1005 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
1006 return false;
1007 for (Level l = startLvl + 1; l < lvlRank; ++l)
1008 if (!isSingletonLvl(l))
1009 return false;
1010 // If isUnique is true, then make sure that the last level is unique,
1011 // that is, when lvlRank == 1, the only compressed level is unique,
1012 // and when lvlRank > 1, the last singleton is unique.
1013 return !isUnique || isUniqueLvl(lvlRank - 1);
1014}
1015
1016RankedTensorType
1018 SmallVector<LevelType> lvlTypes;
1019 lvlTypes.reserve(lvlRank);
1020 // A non-unique compressed level at beginning (unless this is
1021 // also the last level, then it is unique).
1022 lvlTypes.push_back(
1023 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
1024 if (lvlRank > 1) {
1025 // Followed by n-2 non-unique singleton levels.
1026 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1027 *buildLevelType(LevelFormat::Singleton, ordered, false));
1028 // Ends by a unique singleton level.
1029 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1030 }
1031 auto enc = SparseTensorEncodingAttr::get(
1032 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1034 return RankedTensorType::get(getDimShape(), getElementType(), enc);
1035}
1036
1037//===----------------------------------------------------------------------===//
1038// Convenience Methods.
1039//===----------------------------------------------------------------------===//
1040
1041SparseTensorEncodingAttr
1043 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1044 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1045 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1046 return mdtp.getEncoding();
1047 return nullptr;
1048}
1049
1051 MLIRContext *context) {
1052 auto map = static_cast<AffineMap>(dimToLvl);
1053 AffineMap lvlToDim;
1054 // Return an empty lvlToDim when inference is not successful.
1055 if (!map || map.getNumSymbols() != 0) {
1056 lvlToDim = AffineMap();
1057 } else if (map.isPermutation()) {
1058 lvlToDim = inversePermutation(map);
1059 } else if (isBlockSparsity(map)) {
1060 lvlToDim = inverseBlockSparsity(map, context);
1061 }
1062 return lvlToDim;
1063}
1064
1066 MLIRContext *context) {
1067 SmallVector<AffineExpr> lvlExprs;
1068 auto numLvls = dimToLvl.getNumResults();
1069 lvlExprs.reserve(numLvls);
1070 // lvlExprComponents stores information of the floordiv and mod operations
1071 // applied to the same dimension, so as to build the lvlToDim map.
1072 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1073 for (unsigned i = 0, n = numLvls; i < n; i++) {
1074 auto result = dimToLvl.getResult(i);
1075 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1076 if (result.getKind() == AffineExprKind::FloorDiv) {
1077 // Position of the dimension in dimToLvl.
1078 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1079 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1080 "expected only one floordiv for each dimension");
1081 SmallVector<AffineExpr, 3> components;
1082 // Level variable for floordiv.
1083 components.push_back(getAffineDimExpr(i, context));
1084 // Multiplier.
1085 components.push_back(binOp.getRHS());
1086 // Map key is the position of the dimension.
1087 lvlExprComponents[pos] = components;
1088 } else if (result.getKind() == AffineExprKind::Mod) {
1089 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1090 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1091 "expected floordiv before mod");
1092 // Add level variable for mod to the same vector
1093 // of the corresponding floordiv.
1094 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1095 } else {
1096 assert(false && "expected floordiv or mod");
1097 }
1098 } else {
1099 lvlExprs.push_back(getAffineDimExpr(i, context));
1100 }
1101 }
1102 // Build lvlExprs from lvlExprComponents.
1103 // For example, for il = i floordiv 2 and ii = i mod 2, the components
1104 // would be [il, 2, ii]. It could be used to build the AffineExpr
1105 // i = il * 2 + ii in lvlToDim.
1106 for (auto &components : lvlExprComponents) {
1107 assert(components.second.size() == 3 &&
1108 "expected 3 components to build lvlExprs");
1109 auto mulOp = getAffineBinaryOpExpr(
1110 AffineExprKind::Mul, components.second[0], components.second[1]);
1111 auto addOp =
1112 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1113 lvlExprs.push_back(addOp);
1114 }
1115 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1116}
1117
1119 assert(isBlockSparsity(dimToLvl) &&
1120 "expected dimToLvl to be block sparsity for calling getBlockSize");
1121 SmallVector<unsigned> blockSize;
1122 for (auto result : dimToLvl.getResults()) {
1123 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1124 if (result.getKind() == AffineExprKind::Mod) {
1125 blockSize.push_back(
1126 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1127 }
1128 } else {
1129 blockSize.push_back(0);
1130 }
1131 }
1132 return blockSize;
1133}
1134
1136 if (!dimToLvl)
1137 return false;
1138 std::map<unsigned, int64_t> coeffientMap;
1139 bool hasBlock = false;
1140 for (auto result : dimToLvl.getResults()) {
1141 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1142 // Check for "dim op const".
1143 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1144 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1145 if (!dimOp || !conOp || conOp.getValue() <= 0)
1146 return false;
1147 // Inspect "dim / const" or "dim % const".
1148 auto pos = dimOp.getPosition();
1149 if (binOp.getKind() == AffineExprKind::FloorDiv) {
1150 // Expect only one floordiv for each dimension.
1151 auto [it, inserted] = coeffientMap.try_emplace(pos);
1152 if (!inserted)
1153 return false;
1154 // Record coefficient of the floordiv.
1155 it->second = conOp.getValue();
1156 } else if (binOp.getKind() == AffineExprKind::Mod) {
1157 // Expect floordiv before mod.
1158 auto it = coeffientMap.find(pos);
1159 if (it == coeffientMap.end())
1160 return false;
1161 // Expect mod to have the same coefficient as floordiv.
1162 if (conOp.getValue() != it->second)
1163 return false;
1164 hasBlock = true;
1165 } else {
1166 return false;
1167 }
1168 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1169 auto pos = dimOp.getPosition();
1170 // Expect dim to be unset.
1171 if (!coeffientMap.try_emplace(pos, 0).second)
1172 return false;
1173 } else {
1174 return false;
1175 }
1176 }
1177 return hasBlock;
1178}
1179
1181 auto hasNonIdentityMap = [](Value v) {
1182 auto stt = tryGetSparseTensorType(v);
1183 return stt && !stt->isIdentity();
1184 };
1185
1186 return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1187 llvm::any_of(op->getResults(), hasNonIdentityMap);
1188}
1189
1190Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1191 if (enc) {
1192 assert(enc.isPermutation() && "Non permutation map not supported");
1193 if (const auto dimToLvl = enc.getDimToLvl())
1194 return dimToLvl.getDimPosition(l);
1195 }
1196 return l;
1197}
1198
1199Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1200 if (enc) {
1201 assert(enc.isPermutation() && "Non permutation map not supported");
1202 if (const auto lvlToDim = enc.getLvlToDim())
1203 return lvlToDim.getDimPosition(d);
1204 }
1205 return d;
1206}
1207
1208/// We normalized sparse tensor encoding attribute by always using
1209/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1210/// as other variants) lead to the same storage specifier type, and stripping
1211/// irrelevant fields that do not alter the sparse tensor memory layout.
1212static SparseTensorEncodingAttr
1213getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1215 for (auto lt : enc.getLvlTypes())
1216 lts.push_back(lt.stripStorageIrrelevantProperties());
1217
1218 return SparseTensorEncodingAttr::get(
1219 enc.getContext(), lts,
1220 AffineMap(), // dimToLvl (irrelevant to storage specifier)
1221 AffineMap(), // lvlToDim (irrelevant to storage specifier)
1222 // Always use `index` for memSize and lvlSize instead of reusing
1223 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
1224 // value for different bitwidth, it also avoids casting between index and
1225 // integer (returned by DimOp)
1226 0, 0,
1227 Attribute(), // explicitVal (irrelevant to storage specifier)
1228 Attribute(), // implicitVal (irrelevant to storage specifier)
1229 enc.getDimSlices());
1230}
1231
1232StorageSpecifierType
1233StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1234 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1235}
1236
1237StorageSpecifierType
1238StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1239 MLIRContext *ctx,
1240 SparseTensorEncodingAttr encoding) {
1241 return Base::getChecked(emitError, ctx,
1243}
1244
1245//===----------------------------------------------------------------------===//
1246// SparseTensorDialect Operations.
1247//===----------------------------------------------------------------------===//
1248
1249static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1250 return success(lvl < getSparseTensorType(tensor).getLvlRank());
1251}
1252
1253static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1254 const Type etp = getMemRefType(mem).getElementType();
1255 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1256}
1257
1258static LogicalResult verifySparsifierGetterSetter(
1259 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1261 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1262 return op->emitError(
1263 "redundant level argument for querying value memory size");
1264 }
1265
1266 const auto enc = md.getType().getEncoding();
1267 const Level lvlRank = enc.getLvlRank();
1268
1269 if (mdKind == StorageSpecifierKind::DimOffset ||
1270 mdKind == StorageSpecifierKind::DimStride)
1271 if (!enc.isSlice())
1272 return op->emitError("requested slice data on non-slice tensor");
1273
1274 if (mdKind != StorageSpecifierKind::ValMemSize) {
1275 if (!lvl)
1276 return op->emitError("missing level argument");
1277
1278 const Level l = lvl.value();
1279 if (l >= lvlRank)
1280 return op->emitError("requested level is out of bounds");
1281
1282 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1283 return op->emitError(
1284 "requested position memory size on a singleton level");
1285 }
1286 return success();
1287}
1288
1290 switch (kind) {
1292 return stt.getCrdType();
1294 return stt.getPosType();
1296 return stt.getElementType();
1298 return nullptr;
1299 }
1300 llvm_unreachable("Unrecognizable FieldKind");
1301}
1302
1303static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1304 SparseTensorType stt,
1305 RankedTensorType valTp,
1306 TypeRange lvlTps) {
1307 if (requiresStaticShape && !stt.hasStaticDimShape())
1308 return op->emitError("the sparse-tensor must have static shape");
1309 if (!stt.hasEncoding())
1310 return op->emitError("the sparse-tensor must have an encoding attribute");
1311
1312 // Verifies the trailing COO.
1313 Level cooStartLvl = stt.getAoSCOOStart();
1314 if (cooStartLvl < stt.getLvlRank()) {
1315 // We only supports trailing COO for now, must be the last input.
1316 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1317 // The coordinates should be in shape of <? x rank>
1318 unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1319 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1320 return op->emitError("input/output trailing COO level-ranks don't match");
1321 }
1322 }
1323
1324 // Verifies that all types match.
1325 StorageLayout layout(stt.getEncoding());
1326 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1327 return op->emitError("inconsistent number of fields between input/output");
1328
1329 unsigned idx = 0;
1330 bool misMatch = false;
1331 layout.foreachField([&idx, &misMatch, stt, valTp,
1332 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1333 Level lvl, LevelType lt) -> bool {
1335 return true;
1336
1337 Type inputTp = nullptr;
1338 if (fKind == SparseTensorFieldKind::ValMemRef) {
1339 inputTp = valTp;
1340 } else {
1341 assert(fid == idx && stt.getLvlType(lvl) == lt);
1342 inputTp = lvlTps[idx++];
1343 }
1344 // The input element type and expected element type should match.
1345 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1346 Type expElemTp = getFieldElemType(stt, fKind);
1347 if (inpElemTp != expElemTp) {
1348 misMatch = true;
1349 return false; // to terminate the iteration
1350 }
1351 return true;
1352 });
1353
1354 if (misMatch)
1355 return op->emitError("input/output element-types don't match");
1356 return success();
1357}
1358
1359LogicalResult AssembleOp::verify() {
1360 RankedTensorType valuesTp = getValues().getType();
1361 const auto lvlsTp = getLevels().getTypes();
1362 const auto resTp = getSparseTensorType(getResult());
1363 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1364}
1365
1366LogicalResult DisassembleOp::verify() {
1367 if (getOutValues().getType() != getRetValues().getType())
1368 return emitError("output values and return value type mismatch");
1369
1370 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1371 if (ot.getType() != rt.getType())
1372 return emitError("output levels and return levels type mismatch");
1373
1374 RankedTensorType valuesTp = getRetValues().getType();
1375 const auto lvlsTp = getRetLevels().getTypes();
1376 const auto srcTp = getSparseTensorType(getTensor());
1377 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1378}
1379
1380LogicalResult ConvertOp::verify() {
1381 RankedTensorType tp1 = getSource().getType();
1382 RankedTensorType tp2 = getDest().getType();
1383 if (tp1.getRank() != tp2.getRank())
1384 return emitError("unexpected conversion mismatch in rank");
1385 auto dstEnc =
1386 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1387 if (dstEnc && dstEnc.isSlice())
1388 return emitError("cannot convert to a sparse tensor slice");
1389
1390 auto shape1 = tp1.getShape();
1391 auto shape2 = tp2.getShape();
1392 // Accept size matches between the source and the destination type
1393 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1394 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1395 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1396 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1397 return emitError("unexpected conversion mismatch in dimension ") << d;
1398 return success();
1399}
1400
1401OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1402 if (getType() == getSource().getType())
1403 return getSource();
1404 return {};
1405}
1406
1407bool ConvertOp::needsExtraSort() {
1408 SparseTensorType srcStt = getSparseTensorType(getSource());
1409 SparseTensorType dstStt = getSparseTensorType(getDest());
1410
1411 // We do not need an extra sort when returning unordered sparse tensors or
1412 // dense tensor since dense tensor support random access.
1413 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1414 return false;
1415
1416 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1417 srcStt.hasSameDimToLvl(dstStt)) {
1418 return false;
1419 }
1420
1421 // Source and dest tensors are ordered in different ways. We only do direct
1422 // dense to sparse conversion when the dense input is defined by a sparse
1423 // constant. Note that we can theoretically always directly convert from dense
1424 // inputs by rotating dense loops but it leads to bad cache locality and hurt
1425 // performance.
1426 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1427 if (isa<SparseElementsAttr>(constOp.getValue()))
1428 return false;
1429
1430 return true;
1431}
1432
1433LogicalResult CrdTranslateOp::verify() {
1434 uint64_t inRank = getEncoder().getLvlRank();
1435 uint64_t outRank = getEncoder().getDimRank();
1436
1437 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1438 std::swap(inRank, outRank);
1439
1440 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1441 return emitError("Coordinate rank mismatch with encoding");
1442
1443 return success();
1444}
1445
// Folds coordinate translations: identity maps fold to the inputs,
// permutations fold to a reordering of the inputs, and a dim2lvl/lvl2dim
// pair over the same mapping cancels out.
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  // Identity mapping: the inputs are the outputs.
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  // Permutation: pick the input coordinate each result position refers to.
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  // All input coordinates must be produced by one and the same translate op.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  // The producer must translate in the opposite direction, over the same
  // mapping, and yield exactly as many coordinates as consumed here.
  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}
1491
1492void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1493 int64_t index) {
1494 Value val = arith::ConstantIndexOp::create(builder, state.location, index);
1495 return build(builder, state, source, val);
1496}
1497
1498LogicalResult LvlOp::verify() {
1499 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1500 auto stt = getSparseTensorType(getSource());
1501 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1502 return emitError(
1503 "Level index exceeds the rank of the input sparse tensor");
1504 }
1505 return success();
1506}
1507
1508std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1509 return getConstantIntValue(getIndex());
1510}
1511
1512Speculation::Speculatability LvlOp::getSpeculatability() {
1513 auto constantIndex = getConstantLvlIndex();
1514 if (!constantIndex)
1516
1517 assert(constantIndex <
1518 cast<RankedTensorType>(getSource().getType()).getRank());
1520}
1521
1522OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1523 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1524 if (!lvlIndex)
1525 return {};
1526
1527 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1528 auto stt = getSparseTensorType(getSource());
1529 if (lvl >= stt.getLvlRank()) {
1530 // Follows the same convention used by tensor.dim operation. Out of bound
1531 // indices produce undefined behavior but are still valid IR. Don't choke on
1532 // them.
1533 return {};
1534 }
1535
1536 // Helper lambda to build an IndexAttr.
1537 auto getIndexAttr = [this](int64_t lvlSz) {
1538 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1539 };
1540
1541 SmallVector<Size> lvlShape = stt.getLvlShape();
1542 if (ShapedType::isStatic(lvlShape[lvl]))
1543 return getIndexAttr(lvlShape[lvl]);
1544
1545 return {};
1546}
1547
1548void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1549 SparseTensorEncodingAttr dstEnc, Value source) {
1550 auto srcStt = getSparseTensorType(source);
1551 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1552 SmallVector<int64_t> dstDimShape =
1553 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1554 auto dstTp =
1555 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1556 return build(odsBuilder, odsState, dstTp, source);
1557}
1558
1559LogicalResult ReinterpretMapOp::verify() {
1560 auto srcStt = getSparseTensorType(getSource());
1561 auto dstStt = getSparseTensorType(getDest());
1562 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1563 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1564
1565 if (srcLvlTps.size() != dstLvlTps.size())
1566 return emitError("Level rank mismatch between source/dest tensors");
1567
1568 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1569 if (srcLvlTp != dstLvlTp)
1570 return emitError("Level type mismatch between source/dest tensors");
1571
1572 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1573 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1574 return emitError("Crd/Pos width mismatch between source/dest tensors");
1575 }
1576
1577 if (srcStt.getElementType() != dstStt.getElementType())
1578 return emitError("Element type mismatch between source/dest tensors");
1579
1580 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1581 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1582 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1583 if (srcLvlSz != dstLvlSz) {
1584 // Should we allow one side to be dynamic size, e.g., <?x?> should be
1585 // compatible to <3x4>? For now, we require all the level sizes to be
1586 // *exactly* matched for simplicity.
1587 return emitError("Level size mismatch between source/dest tensors");
1588 }
1589 }
1590
1591 return success();
1592}
1593
1594OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1595 if (getSource().getType() == getDest().getType())
1596 return getSource();
1597
1598 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1599 // A -> B, B -> A ==> A
1600 if (def.getSource().getType() == getDest().getType())
1601 return def.getSource();
1602 }
1603 return {};
1604}
1605
1606template <typename ToBufferOp>
1607static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1608 PropertyRef prop, RegionRange region,
1610 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1611 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1612 Type elemTp = nullptr;
1613 bool withStride = false;
1614 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1615 elemTp = stt.getPosType();
1616 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1617 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1618 elemTp = stt.getCrdType();
1619 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1620 withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1621 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1622 elemTp = stt.getElementType();
1623 }
1624
1625 assert(elemTp && "unhandled operation.");
1626 SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1627 bufShape.push_back(ShapedType::kDynamic);
1628
1629 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1630 stt.getContext(), ShapedType::kDynamic,
1631 {ShapedType::kDynamic})
1632 : StridedLayoutAttr();
1633 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1634 return success();
1635}
1636
1637LogicalResult ToPositionsOp::verify() {
1638 auto stt = getSparseTensorType(getTensor());
1639 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1640 return emitError("requested level is out of bounds");
1641 if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1642 return emitError("unexpected type for positions");
1643 return success();
1644}
1645
1646LogicalResult
1647ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1648 ValueRange ops, DictionaryAttr attr,
1649 PropertyRef prop, RegionRange region,
1650 SmallVectorImpl<mlir::Type> &ret) {
1651 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1652}
1653
1654LogicalResult ToCoordinatesOp::verify() {
1655 auto stt = getSparseTensorType(getTensor());
1656 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1657 return emitError("requested level is out of bounds");
1658 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1659 return emitError("unexpected type for coordinates");
1660 return success();
1661}
1662
1663LogicalResult
1664ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1665 ValueRange ops, DictionaryAttr attr,
1666 PropertyRef prop, RegionRange region,
1667 SmallVectorImpl<mlir::Type> &ret) {
1668 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1669}
1670
1671LogicalResult ToCoordinatesBufferOp::verify() {
1672 auto stt = getSparseTensorType(getTensor());
1673 if (stt.getAoSCOOStart() >= stt.getLvlRank())
1674 return emitError("expected sparse tensor with a COO region");
1675 return success();
1676}
1677
1678LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1679 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1680 DictionaryAttr attr, PropertyRef prop, RegionRange region,
1681 SmallVectorImpl<mlir::Type> &ret) {
1682 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1683 ret);
1684}
1685
1686LogicalResult ToValuesOp::verify() {
1687 auto stt = getSparseTensorType(getTensor());
1688 auto mtp = getMemRefType(getResult());
1689 if (stt.getElementType() != mtp.getElementType())
1690 return emitError("unexpected mismatch in element types");
1691 return success();
1692}
1693
1694LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1695 std::optional<Location> loc,
1696 ValueRange ops, DictionaryAttr attr,
1697 PropertyRef prop, RegionRange region,
1698 SmallVectorImpl<mlir::Type> &ret) {
1699 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1700}
1701
1702LogicalResult ToSliceOffsetOp::verify() {
1703 auto rank = getSlice().getType().getRank();
1704 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1705 return emitError("requested dimension out of bound");
1706 return success();
1707}
1708
1709LogicalResult ToSliceStrideOp::verify() {
1710 auto rank = getSlice().getType().getRank();
1711 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1712 return emitError("requested dimension out of bound");
1713 return success();
1714}
1715
1716LogicalResult GetStorageSpecifierOp::verify() {
1717 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1718 getSpecifier(), getOperation());
1719}
1720
1721template <typename SpecifierOp>
1722static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1723 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1724}
1725
1726OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1727 const StorageSpecifierKind kind = getSpecifierKind();
1728 const auto lvl = getLevel();
1729 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1730 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1731 return op.getValue();
1732 return {};
1733}
1734
1735LogicalResult SetStorageSpecifierOp::verify() {
1736 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1737 getSpecifier(), getOperation());
1738}
1739
1740template <class T>
1741static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1742 const char *regionName,
1743 TypeRange inputTypes, Type outputType) {
1744 unsigned numArgs = region.getNumArguments();
1745 unsigned expectedNum = inputTypes.size();
1746 if (numArgs != expectedNum)
1747 return op->emitError() << regionName << " region must have exactly "
1748 << expectedNum << " arguments";
1749
1750 for (unsigned i = 0; i < numArgs; i++) {
1751 Type typ = region.getArgument(i).getType();
1752 if (typ != inputTypes[i])
1753 return op->emitError() << regionName << " region argument " << (i + 1)
1754 << " type mismatch";
1755 }
1756 Block &block = region.front();
1757 if (!block.mightHaveTerminator())
1758 return op->emitError() << regionName
1759 << " region must end with a terminator";
1760
1761 Operation *term = block.getTerminator();
1762 YieldOp yield = dyn_cast<YieldOp>(term);
1763 if (!yield)
1764 return op->emitError() << regionName
1765 << " region must end with sparse_tensor.yield";
1766 if (!yield.hasSingleResult() ||
1767 yield.getSingleResult().getType() != outputType)
1768 return op->emitError() << regionName << " region yield type mismatch";
1769
1770 return success();
1771}
1772
1773LogicalResult BinaryOp::verify() {
1774 NamedAttrList attrs = (*this)->getAttrs();
1775 Type leftType = getX().getType();
1776 Type rightType = getY().getType();
1777 Type outputType = getOutput().getType();
1778 Region &overlap = getOverlapRegion();
1779 Region &left = getLeftRegion();
1780 Region &right = getRightRegion();
1781
1782 // Check correct number of block arguments and return type for each
1783 // non-empty region.
1784 if (!overlap.empty()) {
1785 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1786 TypeRange{leftType, rightType}, outputType)))
1787 return failure();
1788 }
1789 if (!left.empty()) {
1790 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1791 outputType)))
1792 return failure();
1793 } else if (getLeftIdentity()) {
1794 if (leftType != outputType)
1795 return emitError("left=identity requires first argument to have the same "
1796 "type as the output");
1797 }
1798 if (!right.empty()) {
1799 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1800 outputType)))
1801 return failure();
1802 } else if (getRightIdentity()) {
1803 if (rightType != outputType)
1804 return emitError("right=identity requires second argument to have the "
1805 "same type as the output");
1806 }
1807 return success();
1808}
1809
// Verifies the present/absent regions of a sparse_tensor.unary op: each
// non-empty region must take the expected arguments and yield the output
// type, and the absent branch may only yield invariant values (constants or
// values defined outside the op's enclosing block).
LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    // The absent region takes no arguments (there is no stored value).
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      // A block argument of the enclosing block would make the absent value
      // iteration dependent.
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      // Any non-constant value computed in the absent region itself or in
      // the enclosing block is likewise rejected.
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}
1843
1844bool ConcatenateOp::needsExtraSort() {
1845 SparseTensorType dstStt = getSparseTensorType(*this);
1846 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1847 return false;
1848
1849 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1850 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1851 });
1852 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1853 // in all input/output buffers, and all input/output buffers have the same
1854 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1855 // CSC matrices along column).
1856 bool directLowerable =
1857 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1858 return !directLowerable;
1859}
1860
1861LogicalResult ConcatenateOp::verify() {
1862 const auto dstTp = getSparseTensorType(*this);
1863 const Dimension concatDim = getDimension();
1864 const Dimension dimRank = dstTp.getDimRank();
1865
1866 if (getInputs().size() <= 1)
1867 return emitError("Need at least two tensors to concatenate.");
1868
1869 if (concatDim >= dimRank)
1870 return emitError(llvm::formatv(
1871 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1872 concatDim, dimRank));
1873
1874 for (const auto &it : llvm::enumerate(getInputs())) {
1875 const auto i = it.index();
1876 const auto srcTp = getSparseTensorType(it.value());
1877 if (srcTp.hasDynamicDimShape())
1878 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1879 const Dimension srcDimRank = srcTp.getDimRank();
1880 if (srcDimRank != dimRank)
1881 return emitError(
1882 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1883 "from the output tensor (rank={2}).",
1884 i, srcDimRank, dimRank));
1885 }
1886
1887 for (Dimension d = 0; d < dimRank; d++) {
1888 const Size dstSh = dstTp.getDimShape()[d];
1889 if (d == concatDim) {
1890 if (ShapedType::isStatic(dstSh)) {
1891 // If we reach here, then all inputs have static shapes. So we
1892 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1893 // to avoid redundant assertions in the loop.
1894 Size sumSz = 0;
1895 for (const auto src : getInputs())
1896 sumSz += getSparseTensorType(src).getDimShape()[d];
1897 // If all dimension are statically known, the sum of all the input
1898 // dimensions should be equal to the output dimension.
1899 if (sumSz != dstSh)
1900 return emitError(
1901 "The concatenation dimension of the output tensor should be the "
1902 "sum of all the concatenation dimensions of the input tensors.");
1903 }
1904 } else {
1905 Size prev = dstSh;
1906 for (const auto src : getInputs()) {
1907 const auto sh = getSparseTensorType(src).getDimShape()[d];
1908 if (ShapedType::isStatic(prev) && sh != prev)
1909 return emitError("All dimensions (expect for the concatenating one) "
1910 "should be equal.");
1911 prev = sh;
1912 }
1913 }
1914 }
1915
1916 return success();
1917}
1918
// Convenience builder: push a single value (no `n` repetition operand).
void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  // Delegate to the generated builder with an absent `n` operand.
  build(builder, result, curSize, inBuffer, value, Value());
}
1923
1924LogicalResult PushBackOp::verify() {
1925 if (Value n = getN()) {
1926 std::optional<int64_t> nValue = getConstantIntValue(n);
1927 if (nValue && nValue.value() < 1)
1928 return emitOpError("n must be not less than 1");
1929 }
1930 return success();
1931}
1932
1933LogicalResult CompressOp::verify() {
1934 const auto stt = getSparseTensorType(getTensor());
1935 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1936 return emitOpError("incorrect number of coordinates");
1937 return success();
1938}
1939
// Builds a sparse_tensor.foreach op and, when a body builder is supplied,
// populates its region. The block arguments are laid out as:
//   [dimRank coordinates, one element value, the reduction variables].
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  // Hand the builder the three argument groups in their declared order.
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
1970
// Verifies a sparse_tensor.foreach op: traversal order (if any) matches the
// level rank, the body block has [dimRank coords, element, init args], and
// result/init/yield types all line up.
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  // Block layout: dimRank coordinates + 1 element value + init args.
  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  // All leading coordinate arguments must be of index type.
  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  // The value argument must carry the tensor's element type.
  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}
2009
2010OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2011 if (getSparseTensorEncoding(getInputCoo().getType()) ==
2012 getSparseTensorEncoding(getResultCoo().getType()))
2013 return getInputCoo();
2014
2015 return {};
2016}
2017
// Verifies a sparse_tensor.reorder_coo op: both operands must be COO tensors
// with matching dim-to-lvl map and identical storage formats.
LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  // Position/coordinate/element types must match so only the ordering of
  // stored entries differs.
  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");

  return success();
}
2035
2036LogicalResult ReduceOp::verify() {
2037 Type inputType = getX().getType();
2038 Region &formula = getRegion();
2039 return verifyNumBlockArgs(this, formula, "reduce",
2040 TypeRange{inputType, inputType}, inputType);
2041}
2042
2043LogicalResult SelectOp::verify() {
2044 Builder b(getContext());
2045 Type inputType = getX().getType();
2046 Type boolType = b.getI1Type();
2047 Region &formula = getRegion();
2048 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2049 boolType);
2050}
2051
2052LogicalResult SortOp::verify() {
2053 AffineMap xPerm = getPermMap();
2054 uint64_t nx = xPerm.getNumDims();
2055 if (nx < 1)
2056 return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2057
2058 if (!xPerm.isPermutation())
2059 return emitError(
2060 llvm::formatv("Expected a permutation map, got {0}", xPerm));
2061
2062 // We can't check the size of the buffers when n or buffer dimensions aren't
2063 // compile-time constants.
2064 std::optional<int64_t> cn = getConstantIntValue(getN());
2065 if (!cn)
2066 return success();
2067
2068 // Verify dimensions.
2069 const auto checkDim = [&](Value v, Size minSize,
2070 const char *message) -> LogicalResult {
2071 const Size sh = getMemRefType(v).getShape()[0];
2072 if (ShapedType::isStatic(sh) && sh < minSize)
2073 return emitError(
2074 llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2075 return success();
2076 };
2077 uint64_t n = cn.value();
2078 uint64_t ny = 0;
2079 if (auto nyAttr = getNyAttr())
2080 ny = nyAttr.getInt();
2081 if (failed(checkDim(getXy(), n * (nx + ny),
2082 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2083 return failure();
2084 for (Value opnd : getYs())
2085 if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2086 return failure();
2087
2088 return success();
2089}
2090
2091//===----------------------------------------------------------------------===//
2092// Sparse Tensor Iteration Operations.
2093//===----------------------------------------------------------------------===//
2094
2095IterSpaceType IteratorType::getIterSpaceType() const {
2096 return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2097 getHiLvl());
2098}
2099
2100IteratorType IterSpaceType::getIteratorType() const {
2101 return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2102}
2103
2104/// Parses a level range in the form "$lo `to` $hi"
2105/// or simply "$lo" if $hi - $lo = 1
2106static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2107 Level &lvlHi) {
2108 if (parser.parseInteger(lvlLo))
2109 return failure();
2110
2111 if (succeeded(parser.parseOptionalKeyword("to"))) {
2112 if (parser.parseInteger(lvlHi))
2113 return failure();
2114 } else {
2115 lvlHi = lvlLo + 1;
2116 }
2117
2118 if (lvlHi <= lvlLo)
2119 return parser.emitError(parser.getNameLoc(),
2120 "expect larger level upper bound than lower bound");
2121
2122 return success();
2123}
2124
2125/// Parses a level range in the form "$lo `to` $hi"
2126/// or simply "$lo" if $hi - $lo = 1
2127static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2128 IntegerAttr &lvlHiAttr) {
2129 Level lvlLo, lvlHi;
2130 if (parseLevelRange(parser, lvlLo, lvlHi))
2131 return failure();
2132
2133 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2134 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2135 return success();
2136}
2137
2138/// Prints a level range in the form "$lo `to` $hi"
2139/// or simply "$lo" if $hi - $lo = 1
2140static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2141
2142 if (lo + 1 == hi)
2143 p << lo;
2144 else
2145 p << lo << " to " << hi;
2146}
2147
2148/// Prints a level range in the form "$lo `to` $hi"
2149/// or simply "$lo" if $hi - $lo = 1
2150static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2151 IntegerAttr lvlHi) {
2152 unsigned lo = lvlLo.getValue().getZExtValue();
2153 unsigned hi = lvlHi.getValue().getZExtValue();
2154 printLevelRange(p, lo, hi);
2155}
2156
2157/// Parses a list of `optional` defined list in the form of
2158/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2159/// corresponding value is not defined (e.g., to represent an undefined
2160/// coordinate in the sparse iteration space).
2161static ParseResult parseOptionalDefinedList(
2162 OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2164 unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2166 unsigned cnt = 0;
2167 ParseResult crdList =
2168 parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2169 if (parser.parseOptionalKeyword("_")) {
2170 if (parser.parseArgument(definedArgs.emplace_back()))
2171 return failure();
2172 definedSet.set(cnt);
2173 }
2174 cnt += 1;
2175 return success();
2176 });
2177
2178 if (cnt > maxCnt)
2179 return parser.emitError(parser.getNameLoc(),
2180 "parsed more value than expected.");
2181
2182 if (failed(crdList)) {
2183 return parser.emitError(
2184 parser.getNameLoc(),
2185 "expecting SSA value or \"_\" for level coordinates");
2186 }
2187 assert(definedArgs.size() == definedSet.count());
2188 return success();
2189}
2190
2191static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2192 Block::BlockArgListType blocksArgs,
2193 I64BitSet definedSet) {
2194 if (definedSet.empty())
2195 return;
2196
2197 for (unsigned i = 0; i < size; i++) {
2198 if (definedSet[i]) {
2199 p << blocksArgs.front();
2200 blocksArgs = blocksArgs.drop_front();
2201 } else {
2202 p << "_";
2203 }
2204 if (i != size - 1)
2205 p << ", ";
2206 }
2207 assert(blocksArgs.empty());
2208}
2209
2210static ParseResult
2213 // Parse "at(%crd0, _, ...)"
2214 I64BitSet crdUsedLvlSet;
2215 if (succeeded(parser.parseOptionalKeyword("at")) &&
2216 failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2217 return failure();
2218
2219 // Always use IndexType for the coordinate.
2220 for (auto &coord : coords)
2221 coord.type = parser.getBuilder().getIndexType();
2222
2223 // Set the CrdUsedLvl bitset.
2224 state.addAttribute("crdUsedLvls",
2225 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2226 return success();
2227}
2228
2229static ParseResult
2235
2236 // Parse "%iters, ... in %spaces, ..."
2237 if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2238 parser.parseOperandList(spaces))
2239 return failure();
2240
2241 if (iterators.size() != spaces.size())
2242 return parser.emitError(
2243 parser.getNameLoc(),
2244 "mismatch in number of sparse iterators and sparse spaces");
2245
2247 if (failed(parseUsedCoordList(parser, state, coords)))
2248 return failure();
2249 size_t numCrds = coords.size();
2250
2251 // Parse "iter_args(%arg = %init, ...)"
2252 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2253 if (hasIterArgs)
2254 if (parser.parseAssignmentList(blockArgs, initArgs))
2255 return failure();
2256
2257 blockArgs.append(coords);
2258
2259 SmallVector<Type> iterSpaceTps;
2260 // parse ": sparse_tensor.iter_space -> ret"
2261 if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2262 return failure();
2263 if (iterSpaceTps.size() != spaces.size())
2264 return parser.emitError(parser.getNameLoc(),
2265 "mismatch in number of iteration space operands "
2266 "and iteration space types");
2267
2268 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2269 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2270 if (!spaceTp)
2271 return parser.emitError(parser.getNameLoc(),
2272 "expected sparse_tensor.iter_space type for "
2273 "iteration space operands");
2274 it.type = spaceTp.getIteratorType();
2275 }
2276
2277 if (hasIterArgs)
2278 if (parser.parseArrowTypeList(state.types))
2279 return failure();
2280
2281 // Resolves input operands.
2282 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2283 state.operands))
2284 return failure();
2285
2286 if (hasIterArgs) {
2287 // Strip off leading args that used for coordinates.
2288 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2289 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2290 return parser.emitError(
2291 parser.getNameLoc(),
2292 "mismatch in number of iteration arguments and return values");
2293 }
2294
2295 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2296 it.type = tp;
2297 if (parser.resolveOperand(init, tp, state.operands))
2298 return failure();
2299 }
2300 }
2301 return success();
2302}
2303
2304static ParseResult
2306 SmallVectorImpl<Value> &spacesVals,
2308
2309 // Parse "(%spaces, ...)"
2312 return failure();
2313
2315 if (failed(parseUsedCoordList(parser, state, coords)))
2316 return failure();
2317 size_t numCrds = coords.size();
2318
2319 // Parse "iter_args(%arg = %init, ...)"
2321 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2322 if (hasIterArgs)
2323 if (parser.parseAssignmentList(blockArgs, initArgs))
2324 return failure();
2325 blockArgs.append(coords);
2326
2327 SmallVector<Type> iterSpaceTps;
2328 // parse ": (sparse_tensor.iter_space, ...) -> ret"
2329 if (parser.parseColon() || parser.parseLParen() ||
2330 parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2331 return failure();
2332
2333 if (iterSpaceTps.size() != spaces.size())
2334 return parser.emitError(parser.getNameLoc(),
2335 "mismatch in number of iteration space operands "
2336 "and iteration space types");
2337
2338 if (hasIterArgs)
2339 if (parser.parseArrowTypeList(state.types))
2340 return failure();
2341
2342 // Resolves input sparse iteration spaces.
2343 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2344 spacesVals))
2345 return failure();
2346 state.operands.append(spacesVals);
2347
2348 if (hasIterArgs) {
2349 // Strip off trailing args that used for coordinates.
2350 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2351 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2352 return parser.emitError(
2353 parser.getNameLoc(),
2354 "mismatch in number of iteration arguments and return values");
2355 }
2356
2357 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2358 it.type = tp;
2359 if (parser.resolveOperand(init, tp, state.operands))
2360 return failure();
2361 }
2362 }
2363 return success();
2364}
2365
2366LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2367 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2368 DictionaryAttr attr, PropertyRef prop, RegionRange region,
2369 SmallVectorImpl<mlir::Type> &ret) {
2370
2371 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2372 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2373 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2374 adaptor.getHiLvl()));
2375 return success();
2376}
2377
2378LogicalResult ExtractIterSpaceOp::verify() {
2379 if (getLoLvl() >= getHiLvl())
2380 return emitOpError("expected smaller level low than level high");
2381
2382 TypedValue<IteratorType> pIter = getParentIter();
2383 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2384 return emitOpError(
2385 "parent iterator should be specified iff level lower bound equals 0");
2386 }
2387
2388 if (pIter) {
2389 IterSpaceType spaceTp = getExtractedSpace().getType();
2390 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2391 return emitOpError(
2392 "mismatch in parent iterator encoding and iteration space encoding.");
2393
2394 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2395 return emitOpError("parent iterator should be used to extract an "
2396 "iteration space from a consecutive level.");
2397 }
2398
2399 return success();
2400}
2401
2402LogicalResult ExtractValOp::verify() {
2403 auto stt = getSparseTensorType(getTensor());
2404 auto itTp = getIterator().getType();
2405
2406 if (stt.getEncoding() != itTp.getEncoding())
2407 return emitOpError("mismatch in tensor encoding and iterator encoding.");
2408
2409 if (stt.getLvlRank() != itTp.getHiLvl())
2410 return emitOpError("must use last-level iterator to extract values. ");
2411
2412 return success();
2413}
2414
2415struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2417
2418 LogicalResult matchAndRewrite(IterateOp iterateOp,
2419 PatternRewriter &rewriter) const override {
2420 I64BitSet newUsedLvls(0);
2421 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2422 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2423 if (auto crd = iterateOp.getLvlCrd(i)) {
2424 if (crd->getUsers().empty())
2425 toRemove.set(crd->getArgNumber());
2426 else
2427 newUsedLvls.set(i);
2428 }
2429 }
2430
2431 // All coordinates are used.
2432 if (toRemove.none())
2433 return failure();
2434
2435 rewriter.startOpModification(iterateOp);
2436 iterateOp.setCrdUsedLvls(newUsedLvls);
2437 iterateOp.getBody()->eraseArguments(toRemove);
2438 rewriter.finalizeOpModification(iterateOp);
2439 return success();
2440 }
2441};
2442
// Registers canonicalization patterns for sparse_tensor.iterate.
void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                            mlir::MLIRContext *context) {
  // Prune coordinate block arguments that are declared used but have no
  // users in the body.
  results.add<RemoveUnusedLvlCrds>(context);
}
2447
2448void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2449 Value iterSpace, ValueRange initArgs) {
2450 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2451 // All ones.
2452 I64BitSet set((1 << rank) - 1);
2453 return build(builder, odsState, iterSpace, initArgs, set);
2454}
2455
// Builds a sparse_tensor.iterate op with an explicit used-coordinate bitset.
// The body block's arguments are laid out, in order, as:
//   [init args..., used coordinates..., iterator].
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(iterSpace);
  odsState.addOperands(initArgs);
  // Store the used-coordinate bitset as a 64-bit integer property.
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  odsState.addTypes(initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(bodyRegion);

  // Starts with a list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Follows by a list of used coordinates.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Ends with sparse iterator
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);
}
2482
2483ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2484 OpAsmParser::Argument iterator;
2485 OpAsmParser::UnresolvedOperand iterSpace;
2486
2487 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2488 if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2489 return failure();
2490 if (iters.size() != 1)
2491 return parser.emitError(parser.getNameLoc(),
2492 "expected only one iterator/iteration space");
2493
2494 iterArgs.append(iters);
2495 Region *body = result.addRegion();
2496 if (parser.parseRegion(*body, iterArgs))
2497 return failure();
2498
2499 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2500
2501 // Parse the optional attribute list.
2502 if (parser.parseOptionalAttrDict(result.attributes))
2503 return failure();
2504
2505 return success();
2506}
2507
2508/// Prints the initialization list in the form of
2509/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2510/// where 'inner' values are assumed to be region arguments and 'outer' values
2511/// are regular SSA values.
2513 Block::BlockArgListType blocksArgs,
2514 ValueRange initializers,
2515 StringRef prefix = "") {
2516 assert(blocksArgs.size() == initializers.size() &&
2517 "expected same length of arguments and initializers");
2518 if (initializers.empty())
2519 return;
2520
2521 p << prefix << '(';
2522 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2523 p << std::get<0>(it) << " = " << std::get<1>(it);
2524 });
2525 p << ")";
2526}
2527
2528template <typename SparseLoopOp>
2529static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2530 if (op.getInitArgs().size() != op.getNumResults()) {
2531 return op.emitOpError(
2532 "mismatch in number of loop-carried values and defined values");
2533 }
2534 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2535 return op.emitOpError("required out-of-bound coordinates");
2536
2537 return success();
2538}
2539
// Structural verification shared with CoIterateOp (see verifySparseLoopOp).
LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
// Structural verification shared with IterateOp (see verifySparseLoopOp).
LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2542
// Prints a sparse_tensor.iterate op:
//   %iter in %space (at(...))? (iter_args(...))? : <space type>
//   (-> <result types>)? <region>
void IterateOp::print(OpAsmPrinter &p) {
  p << " " << getIterator() << " in " << getIterSpace();
  // Emit the "at(...)" clause only when some coordinate is used.
  if (!getCrdUsedLvls().empty()) {
    p << " at(";
    printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
    p << ")";
  }
  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");

  p << " : " << getIterSpace().getType() << " ";
  if (!getInitArgs().empty())
    p.printArrowTypeList(getInitArgs().getTypes());

  p << " ";
  // Terminators only need printing when loop-carried values exist.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
}
2560
// Region verification for sparse_tensor.iterate: the iterator argument must
// match the space's iterator type, and init operands, region iter args,
// yielded values, and results must agree in count and type.
LogicalResult IterateOp::verifyRegions() {
  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
    return emitOpError("mismatch in iterator and iteration space type");
  if (getNumRegionIterArgs() != getNumResults())
    return emitOpError(
        "mismatch in number of basic block args and defined values");

  auto initArgs = getInitArgs();
  auto iterArgs = getRegionIterArgs();
  auto yieldVals = getYieldedValues();
  auto opResults = getResults();
  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
                        opResults.size()})) {
    return emitOpError() << "number mismatch between iter args and results.";
  }

  // Type-check each loop-carried value against the corresponding result.
  for (auto [i, init, iter, yield, ret] :
       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
    if (init.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
    if (iter.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";
    if (yield.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th yield value and defined value";
  }

  return success();
}
2592
/// OpInterfaces' methods implemented by IterateOp.
// LoopLikeOpInterface: the loop has a single body region.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2595
// The loop-carried init operands are exactly the `initArgs` operands.
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
  return getInitArgsMutable();
}
2599
// The region iter args are the leading block arguments (they precede the
// used coordinates and the iterator; see IterateOp::build).
Block::BlockArgListType IterateOp::getRegionIterArgs() {
  return getRegion().getArguments().take_front(getNumRegionIterArgs());
}
2603
// The yielded values are the operands of the terminating
// sparse_tensor.yield of the single body block.
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
  return cast<sparse_tensor::YieldOp>(
             getRegion().getBlocks().front().getTerminator())
      .getResultsMutable();
}
2609
// All op results are loop results.
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2611
// On first entry into the body, the iter args are fed by the init operands.
OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  return getInitArgs();
}
2615
// RegionBranchOpInterface: control may enter the body or bypass it entirely.
void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself.
  regions.push_back(RegionSuccessor(&getRegion()));
  // It is possible for loop not to enter the body.
  regions.push_back(RegionSuccessor::parent());
}
2624
2625ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) {
2626 return successor.isParent() ? ValueRange(getResults())
2627 : ValueRange(getRegionIterArgs());
2628}
2629
2630void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2631 ValueRange iterSpaces, ValueRange initArgs,
2632 unsigned numCases) {
2633 unsigned rank =
2634 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2635 // All ones.
2636 I64BitSet set((1 << rank) - 1);
2637 // Generates all-zero case bits (they only serve as placeholders), which are
2638 // supposed to be overriden later. We need to preallocate all the regions as
2639 // mlir::Region cannot be dynamically added later after the operation is
2640 // created.
2641 SmallVector<int64_t> caseBits(numCases, 0);
2642 ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2643 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2644 initArgs, set, cases,
2645 /*caseRegionsCount=*/numCases);
2646}
2647
2648ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2649
2650 SmallVector<Value> spaces;
2651 // The block argument list of each regions, it is arranged in the order of
2652 // ([used coordinate list], [loop iterations args], [sparse iterator list]).
2653 SmallVector<OpAsmParser::Argument> blockArgs;
2654 if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2655 return failure();
2656
2657 result.addAttribute("operandSegmentSizes",
2659 {static_cast<int32_t>(spaces.size()),
2660 static_cast<int32_t>(result.types.size())}));
2661
2662 SmallVector<Attribute> cases;
2663 while (succeeded(parser.parseOptionalKeyword("case"))) {
2664 // Parse one region per case.
2665 I64BitSet definedItSet;
2666 SmallVector<OpAsmParser::Argument> definedIts;
2667 if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2668 spaces.size(), OpAsmParser::Delimiter::None))
2669 return failure();
2670
2671 cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2672
2673 for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2674 // Resolve the iterator type based on the iteration space type.
2675 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2676 definedIts[i].type = spaceTp.getIteratorType();
2677 }
2678 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2679 Region *body = result.addRegion();
2680 if (parser.parseRegion(*body, definedIts))
2681 return failure();
2682
2683 CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2684 }
2685
2686 result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2687
2688 // Parse the optional attribute list.
2689 if (parser.parseOptionalAttrDict(result.attributes))
2690 return failure();
2691
2692 return success();
2693}
2694
void CoIterateOp::print(OpAsmPrinter &p) {
  // Print the parenthesized, comma-separated iteration-space operands.
  p << " (";
  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
  p << ")";

  // Print "at(...)" listing the coordinates in use; slots for unused levels
  // are elided by printOptionalDefinedList (rendered as "_").
  if (!getCrdUsedLvls().empty()) {
    p << " at(";
    printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
    p << ")";
  }

  // Print "iter_args(%inner = %outer, ...)" for the loop-carried values.
  printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");

  // Functional type: "(space types)" with an arrow type list appended only
  // when the loop carries iter args (and therefore produces results).
  p << " : (" << getIterSpaces().getTypes() << ")";
  if (!getInitArgs().empty())
    p.printArrowTypeList(getInitArgs().getTypes());

  // One "case" clause per region, each introduced by the iterators that the
  // case defines. Entry block args were printed above; block terminators are
  // printed only when values are yielded back.
  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
    p.printNewline();
    p << "case ";
    printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
                             getRegionDefinedSpace(idx));
    p << " ";
    p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/!getInitArgs().empty());
  }
}
2722
2723ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2724 return cast<sparse_tensor::YieldOp>(
2725 getRegion(regionIdx).getBlocks().front().getTerminator())
2726 .getResults();
2727}
2728
2729LogicalResult CoIterateOp::verifyRegions() {
2730 for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2731 if (getNumRegionIterArgs() != getNumResults())
2732 return emitOpError(
2733 "mismatch in number of basic block args and defined values");
2734
2735 auto initArgs = getInitArgs();
2736 auto iterArgs = getRegionIterArgs(r);
2737 auto yieldVals = getYieldedValues(r);
2738 auto opResults = getResults();
2739 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2740 opResults.size()})) {
2741 return emitOpError()
2742 << "number mismatch between iter args and results on " << r
2743 << "th region";
2744 }
2745
2746 for (auto [i, init, iter, yield, ret] :
2747 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2748 if (init.getType() != ret.getType())
2749 return emitOpError()
2750 << "types mismatch between " << i
2751 << "th iter operand and defined value on " << r << "th region";
2752 if (iter.getType() != ret.getType())
2753 return emitOpError() << "types mismatch between " << i
2754 << "th iter region arg and defined value on " << r
2755 << "th region";
2756 if (yield.getType() != ret.getType())
2757 return emitOpError()
2758 << "types mismatch between " << i
2759 << "th yield value and defined value on " << r << "th region";
2760 }
2761 }
2762
2763 auto cases = getRegionDefinedSpaces();
2764 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2765 if (set.size() != getNumRegions())
2766 return emitOpError("contains duplicated cases.");
2767
2768 return success();
2769}
2770
2771SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2772 SmallVector<Region *> ret;
2773 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2774 for (Region &r : getCaseRegions())
2775 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2776 ret.push_back(&r);
2777
2778 return ret;
2779}
2780
2781//===----------------------------------------------------------------------===//
2782// Sparse Tensor Dialect Setups.
2783//===----------------------------------------------------------------------===//
2784
2785/// Materialize a single constant operation from a given attribute value with
2786/// the desired resultant type.
2787Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2788 Attribute value, Type type,
2789 Location loc) {
2790 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2791 return op;
2792 return nullptr;
2793}
2794
// Registers all TableGen-generated attributes, types, and operations of the
// sparse tensor dialect, plus promised external interface implementations.
void SparseTensorDialect::initialize() {
  // Attributes generated from SparseTensorAttrDefs.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  // Types generated from SparseTensorTypes.td.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  // Operations generated from the dialect's op definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Declare (without defining here) the bufferization interface models that
  // are attached to these ops elsewhere, avoiding a hard dependency.
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2813
2814#define GET_OP_CLASSES
2815#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2816
2817#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:59
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, PropertyRef prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDiml mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isUniqueLT(LevelType lt)
Definition Enums.h:428
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Definition Enums.h:431
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition Enums.h:402
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
Definition Enums.h:432
bool isOrderedLT(LevelType lt)
Definition Enums.h:425
std::string toMLIRString(LevelType lt)
Definition Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
Definition Enums.h:421
static llvm::hash_code hash_value(LevelType lt)
uint64_t getN(LevelType lt)
Definition Enums.h:442
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
Definition Enums.h:413
uint64_t getM(LevelType lt)
Definition Enums.h:443
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
Definition Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Definition Enums.h:424
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition Enums.h:299