MLIR 23.0.0git
SparseTensorDialect.cpp
Go to the documentation of this file.
1//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
17
22#include "mlir/IR/Builders.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32
33// Forward declarations, following custom print/parsing methods are referenced
34// by the generated code for SparseTensorTypes.td.
35static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
40
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
43
44using namespace mlir;
45using namespace mlir::sparse_tensor;
46
47// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
48// well.
namespace mlir::sparse_tensor {
/// Hashes a `LevelType` by hashing its underlying integral encoding.
/// Providing this overload lets LLVM's hashing framework (and thus
/// attribute uniquing) hash values that embed a LevelType, such as
/// SparseTensorEncodingAttr.
static llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor
54
55//===----------------------------------------------------------------------===//
56// Local Convenience Methods.
57//===----------------------------------------------------------------------===//
58
/// Returns true iff `bitWidth` is an admissible overhead-storage bit width
/// for the sparse tensor dialect: 0 (meaning "native index width"), 8, 16,
/// 32, or 64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
71
73getSparseFieldShape(const SparseTensorEncodingAttr enc,
74 std::optional<ArrayRef<int64_t>> dimShape) {
75 assert(enc);
76 // With only encoding, we can not determine the static shape for leading
77 // batch levels, we therefore return a dynamic shape memref instead.
78 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
79 if (dimShape.has_value()) {
80 // If the actual tensor shape is provided, we can then refine the leading
81 // batch dimension.
82 SmallVector<int64_t> lvlShape =
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
86 }
87 // Another dynamic dimension to store the sparse level.
88 memrefShape.push_back(ShapedType::kDynamic);
89 return memrefShape;
90}
91
92//===----------------------------------------------------------------------===//
93// SparseTensorDialect StorageLayout.
94//===----------------------------------------------------------------------===//
95
// Sentinel level for fields (values array, storage specifier) that are not
// tied to any particular level.
static constexpr Level kInvalidLevel = -1u;
// Sentinel returned while a field-index lookup has not yet succeeded.
// NOTE(review): declared with type `Level` although it is used as a
// FieldIndex sentinel — consider `FieldIndex`; confirm against callers.
static constexpr Level kInvalidFieldIndex = -1u;
// Index at which the data fields (memrefs) start in the storage layout.
static constexpr FieldIndex kDataFieldStartingIdx = 0;
99
102 LevelType)>
103 callback) const {
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
106 SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
108
109 ArrayRef cooSegsRef = cooSegs;
110 // Per-level storage.
111 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
112 const auto lt = lvlTypes[l];
113 if (isWithPosLT(lt)) {
114 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
115 return;
116 }
117 if (isWithCrdLT(lt)) {
118 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
119 return;
120 }
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
123 // AoS COO, all singletons are fused into one memrefs. Skips the entire
124 // COO segement.
125 l = cooSegsRef.front().lvlRange.second;
126 } else {
127 // SoA COO, each singleton level has one memref.
128 l++;
129 }
130 // Expire handled COO segment.
131 cooSegsRef = cooSegsRef.drop_front();
132 } else {
133 // Non COO levels.
134 l++;
135 }
136 }
137 // The values array.
138 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
140 return;
141 // Put metadata at the end.
142 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
144 return;
145}
146
150 LevelType)>
151 callback) {
152 assert(stt.hasEncoding());
153
154 SmallVector<int64_t> memrefShape =
156
157 const Type specType = StorageSpecifierType::get(stt.getEncoding());
158 // memref<[batch] x ? x pos> positions
159 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
160 // memref<[batch] x ? x crd> coordinates
161 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
162 // memref<[batch] x ? x eltType> values
163 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
164
165 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
166 callback](FieldIndex fieldIdx,
167 SparseTensorFieldKind fieldKind,
168 Level lvl, LevelType lt) -> bool {
169 switch (fieldKind) {
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
178 };
179 llvm_unreachable("unrecognized field kind");
180 });
181}
182
184 unsigned numFields = 0;
186 LevelType) -> bool {
187 numFields++;
188 return true;
189 });
190 return numFields;
191}
192
194 unsigned numFields = 0; // one value memref
196 LevelType) -> bool {
197 if (fidx >= kDataFieldStartingIdx)
198 numFields++;
199 return true;
200 });
201 numFields -= 1; // the last field is StorageSpecifier
202 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
203 return numFields;
204}
205
206std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl) const {
210 unsigned stride = 1;
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
216 lvl = cooStart;
217 stride = lvlRank - cooStart;
218 }
219 }
220 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
221 SparseTensorFieldKind fKind, Level fLvl,
222 LevelType lt) -> bool {
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
224 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
225 fieldIdx = fIdx;
226 // Returns false to break the iteration.
227 return false;
228 }
229 return true;
230 });
231 assert(fieldIdx != kInvalidFieldIndex);
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
233}
234
235//===----------------------------------------------------------------------===//
236// SparseTensorDialect Attribute Methods.
237//===----------------------------------------------------------------------===//
238
/// Decodes a stored slice component: std::nullopt when `v` is the dynamic
/// sentinel, otherwise the value as an unsigned quantity.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

/// Returns the slice offset if static, std::nullopt if dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

/// Returns the slice stride if static, std::nullopt if dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

/// Returns the slice size if static, std::nullopt if dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

/// Returns true iff offset, stride, and size are all dynamic.
bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

/// Renders one slice component for printing: "?" when dynamic, otherwise
/// its decimal representation.
std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}
264
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
266 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
267 os << '(';
268 os << getStaticString(getOffset());
269 os << ", ";
270 os << getStaticString(getSize());
271 os << ", ";
272 os << getStaticString(getStride());
273 os << ')';
274}
275
/// AsmPrinter overload; forwards to the raw_ostream printer above.
void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}
279
281 AsmParser &parser) {
282 auto parseResult = parser.parseOptionalInteger(result);
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() && result < 0) {
285 parser.emitError(
286 parser.getCurrentLocation(),
287 "expect positive value or ? for slice offset/size/stride");
288 return failure();
289 }
290 return parseResult.value();
291 }
292
293 // Else, and '?' which represented dynamic slice
294 result = SparseTensorDimSliceAttr::kDynamic;
295 return parser.parseQuestion();
296}
297
/// Parses a dim-slice attribute of the form `(offset, size, stride)` where
/// each component is a non-negative integer or `?` (dynamic). Returns a
/// null attribute on parse or verification failure.
Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  // All three components default to dynamic; parseOptionalStaticSlice
  // overwrites them when a static value is present.
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  // getChecked runs the verifier below before constructing the attribute.
  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}
313
/// Verifies slice components: offset must be non-negative, size and stride
/// strictly positive; the dynamic sentinel is always accepted.
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
325
/// Returns a copy of this encoding with `dimToLvl` replaced. The lvlToDim
/// map is reset to the empty AffineMap (not recomputed here).
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

/// Returns a copy of this encoding adopting `enc`'s dimToLvl map (or the
/// empty map when `enc` is null).
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

/// Returns a copy of this encoding with the dimToLvl map cleared.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

/// Returns a copy of this encoding with the given position/coordinate
/// bit widths.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

/// Returns a copy of this encoding with default (0 = index) bit widths.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

/// Returns a copy of this encoding with the given explicit value.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

/// Returns a copy of this encoding with the explicit value cleared.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

/// Returns a copy of this encoding with the given implicit value.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

/// Returns a copy of this encoding with the implicit value cleared.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

/// Returns a copy of this encoding with the given dimension slices.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

/// Returns a copy of this encoding with all dimension slices removed.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
390
/// Returns the number of leading batch levels, computed as the distance
/// from the last batch level (searching from the back) to the front of the
/// level-type list.
uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}
396
/// Returns true iff every level is dense (a null encoding counts as dense).
bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

/// Returns true iff every level is ordered (a null encoding counts as
/// ordered).
bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

/// Returns the element type of coordinate storage: iN for an explicit
/// crdWidth, the native index type otherwise; null for a null encoding.
Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

/// Returns the element type of position storage: iN for an explicit
/// posWidth, the native index type otherwise; null for a null encoding.
Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

/// Returns the memref type of the coordinates array (see
/// getSparseFieldShape for the shape construction).
MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

/// Returns the memref type of the positions array (see getSparseFieldShape
/// for the shape construction).
MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

/// Returns true iff the dim-to-lvl map is absent or the identity.
bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

/// Returns true iff the dim-to-lvl map is absent or a permutation.
bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

/// Returns the dimension rank: the dim-to-lvl map's input count, or the
/// level rank when no map is present.
Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

/// Returns the level rank, defined by the length of the level-type list.
Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

/// Returns the level type at level `l`; a null encoding reports Batch.
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

/// Returns true iff this encoding carries any dimension slices.
bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

/// Returns the slice attribute of dimension `dim`; requires isSlice().
SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

/// Returns the static slice offset of dimension `dim`, if any.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

/// Returns the static slice stride of dimension `dim`, if any.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

/// Returns the static slice offset of level `lvl` by translating the level
/// to its dimension via toDim.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

/// Returns the static slice stride of level `lvl` by translating the level
/// to its dimension via toDim.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}
491
492SmallVector<int64_t>
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir) const {
495 if (isIdentity())
496 return SmallVector<int64_t>(srcShape);
497
498 SmallVector<int64_t> ret;
499 unsigned rank =
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
501 ret.reserve(rank);
502
503 if (isPermutation()) {
504 for (unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
506 : toLvl(*this, r);
507 ret.push_back(srcShape[trans]);
508 }
509 return ret;
510 }
511
512 // Handle non-permutation maps.
513 AffineMap transMap =
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
515
516 SmallVector<AffineExpr> dimRep;
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (ShapedType::isStatic(sz)) {
520 // Push back the max coordinate for the given dimension/level size.
521 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
522 } else {
523 // A dynamic size, use a AffineDimExpr to symbolize the value.
524 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
525 }
526 };
527
528 for (AffineExpr exp : transMap.getResults()) {
529 // Do constant propagation on the affine map.
530 AffineExpr evalExp =
531 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
532 // use llvm namespace here to avoid ambiguity
533 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
534 ret.push_back(c.getValue() + 1);
535 } else {
536 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
537 mod && mod.getKind() == AffineExprKind::Mod) {
538 // We can still infer a static bound for expressions in form
539 // "d % constant" since d % constant \in [0, constant).
540 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
541 ret.push_back(bound.getValue());
542 continue;
543 }
544 }
545 ret.push_back(ShapedType::kDynamic);
546 }
547 }
548 assert(ret.size() == rank);
549 return ret;
550}
551
553SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
554 ValueRange crds,
555 CrdTransDirectionKind dir) const {
556 if (!getImpl())
557 return crds;
558
559 SmallVector<Type> retType(
560 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
561 builder.getIndexType());
562 auto transOp =
563 CrdTranslateOp::create(builder, loc, retType, crds, dir, *this);
564 return transOp.getOutCrds();
565}
566
/// Parses a sparse tensor encoding of the form
/// `<{ map = ..., posWidth = ..., crdWidth = ..., explicitVal = ...,
///    implicitVal = ... }>` where every key is optional and keys are
/// comma-separated. Returns a null attribute on any parse error.
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after keys
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      // Delegate to the dim-lvl map sub-parser, then unpack its result
      // into level types, per-dimension slices, and the two affine maps.
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      // Only float, integer, and complex numbers are admissible.
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      // Only float, integer, and complex numbers are admissible.
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}
713
/// Prints the encoding in the `<{ map = ..., ... }>` syntax, emitting the
/// optional members (widths, explicit/implicit values) only when they
/// differ from their defaults.
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // Empty affine map indicates identity map
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (getExplicitVal()) {
    printer << ", explicitVal = " << getExplicitVal();
  }
  if (getImplicitVal())
    printer << ", implicitVal = " << getImplicitVal();
  printer << " }>";
}
738
739void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
740 AsmPrinter &printer) const {
741 if (map.getNumSymbols() == 0)
742 return;
743 printer << '[';
744 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
745 printer << 's' << i << ", ";
746 if (map.getNumSymbols() >= 1)
747 printer << 's' << map.getNumSymbols() - 1;
748 printer << ']';
749}
750
/// Prints the dimension list of `map` as "d0, d1, ..."; when `dimSlices`
/// is non-empty, each dimension is followed by " : <slice>".
/// NOTE(review): `map.getNumDims() - 1` underflows for a 0-dim map, making
/// the loop bound huge; callers appear to guarantee rank >= 1 — confirm.
void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    // Last dimension is printed without the trailing comma.
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}
768
/// Prints the level list of `map` as "<expr> : <lvlType>, ...", one entry
/// per map result.
/// NOTE(review): `map.getNumResults() - 1` underflows for a result-less
/// map; callers appear to guarantee at least one level — confirm.
void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  // Last level is printed without the trailing comma.
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}
781
782LogicalResult SparseTensorEncodingAttr::verify(
783 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
784 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
785 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
786 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
787 if (!acceptBitWidth(posWidth))
788 return emitError() << "unexpected position bitwidth: " << posWidth;
789 if (!acceptBitWidth(crdWidth))
790 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
791
792 // Verify every COO segment.
793 auto *it = llvm::find_if(lvlTypes, isSingletonLT);
794 while (it != lvlTypes.end()) {
795 if (it == lvlTypes.begin() ||
797 return emitError() << "expected compressed or loose_compressed level "
798 "before singleton level";
799
800 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
801 if (!std::all_of(it, curCOOEnd, isSingletonLT))
802 return emitError() << "expected all singleton lvlTypes "
803 "following a singleton level";
804 // We can potentially support mixed SoA/AoS singleton levels.
805 if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
806 return it->isa<LevelPropNonDefault::SoA>() ==
808 })) {
809 return emitError() << "expected all singleton lvlTypes stored in the "
810 "same memory layout (SoA vs AoS).";
811 }
812 it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
813 }
814
815 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
816 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
817 return emitError() << "Batch lvlType can only be leading levels.";
818
819 // SoA property can only be applied on singleton level.
820 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
821 return lt.isa<LevelPropNonDefault::SoA>();
822 });
823 if (llvm::any_of(soaLvls, [](LevelType lt) {
824 return !lt.isa<LevelFormat::Singleton>();
825 })) {
826 return emitError() << "SoA is only applicable to singleton lvlTypes.";
827 }
828
829 // Dense levels cannot follow a non-unique level. The iteration model for
830 // dense levels requires exactly one parent position to linearize into a
831 // contiguous range, but a non-unique parent provides two cursor values
832 // (segment start and end), which the dense level cannot handle.
833 for (auto [i, lt] : llvm::drop_begin(llvm::enumerate(lvlTypes))) {
834 if (isDenseLT(lt) && !isUniqueLT(lvlTypes[i - 1]))
835 return emitError() << "dense level cannot follow a non-unique level";
836 }
837
838 // TODO: audit formats that actually are supported by backend.
839 if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
840 it != std::end(lvlTypes)) {
841 if (it != lvlTypes.end() - 1)
842 return emitError() << "expected n_out_of_m to be the last level type";
843 if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
844 return emitError() << "expected all dense lvlTypes "
845 "before a n_out_of_m level";
846 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
847 if (!isBlockSparsity(dimToLvl)) {
848 return emitError()
849 << "expected 1xm block structure for n_out_of_m level";
850 }
851 auto sizes = getBlockSize(dimToLvl);
852 unsigned coefficient = 0;
853 for (const auto &elem : sizes) {
854 if (elem != 0) {
855 if (elem != coefficient && coefficient != 0) {
856 return emitError() << "expected only one blocked level "
857 "with the same coefficients";
858 }
859 coefficient = elem;
860 }
861 }
862 if (coefficient != getM(*it)) {
863 return emitError() << "expected coeffiencts of Affine expressions "
864 "to be equal to m of n_out_of_m level";
865 }
866 }
867 }
868 // Before we can check that the level-rank is consistent/coherent
869 // across all fields, we need to define it. The source-of-truth for
870 // the `getLvlRank` method is the length of the level-types array,
871 // since it must always be provided and have full rank; therefore we
872 // use that same source-of-truth here.
873 const Level lvlRank = lvlTypes.size();
874 if (lvlRank == 0)
875 return emitError() << "expected a non-empty array for lvlTypes";
876 // We save `dimRank` here because we'll also need it to verify `dimSlices`.
877 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
878 if (dimToLvl) {
879 if (dimToLvl.getNumResults() != lvlRank)
880 return emitError()
881 << "level-rank mismatch between dimToLvl and lvlTypes: "
882 << dimToLvl.getNumResults() << " != " << lvlRank;
883 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
884 // Symbols can't be inferred but are acceptable.
885 if (!inferRes && dimToLvl.getNumSymbols() == 0)
886 return emitError() << "failed to infer lvlToDim from dimToLvl";
887 if (lvlToDim && (inferRes != lvlToDim))
888 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
889 if (dimRank > lvlRank)
890 return emitError() << "unexpected dimToLvl mapping from " << dimRank
891 << " to " << lvlRank;
892 }
893 if (!dimSlices.empty()) {
894 if (dimSlices.size() != dimRank)
895 return emitError()
896 << "dimension-rank mismatch between dimSlices and dimToLvl: "
897 << dimSlices.size() << " != " << dimRank;
898 // Compiler support for `dimSlices` currently requires that the two
899 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
900 if (dimRank != lvlRank)
901 return emitError()
902 << "dimSlices expected dimension-rank to match level-rank: "
903 << dimRank << " != " << lvlRank;
904 }
905 return success();
906}
907
/// Verifies this encoding against a concrete tensor type: structural
/// integrity (via verify above), matching dimension rank, explicit/implicit
/// value types equal to the element type, and a zero implicit value.
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  if (auto expVal = getExplicitVal()) {
    // NOTE(review): dyn_cast<TypedAttr> yields a null attribute if expVal
    // is not typed, and getType() would then crash; presumably the parser
    // only admits typed numeric attributes here — confirm.
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}
955
956Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
957 SmallVector<COOSegment> coo = getCOOSegments();
958 assert(coo.size() == 1 || coo.empty());
959 if (!coo.empty() && coo.front().isAoS()) {
960 return coo.front().lvlRange.first;
961 }
962 return getLvlRank();
963}
964
SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  // A COO segment spans at least two levels, so rank <= 1 has none.
  if (getLvlRank() <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < getLvlRank()) {
    auto lt = lts[l];
    // NOTE(review): the condition guarding the block below (it should test
    // `lt` for a COO-compatible format) is elided in this rendering --
    // confirm against the original source.
      auto cur = lts.begin() + l;
      // The segment extends over the run of singleton levels that follows.
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes, for now we faithfully assume that all
        // consecutive singleton levels have the same storage format as verified
        // STEA.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
996
997//===----------------------------------------------------------------------===//
998// SparseTensorType Methods.
999//===----------------------------------------------------------------------===//
1000
                         bool isUnique) const {
  // (Signature line elided in this rendering -- confirm name/first parameter
  // against the original source.) Checks for a COO-shaped suffix: a (loose)
  // compressed level at `startLvl` followed only by singleton levels.
  if (!hasEncoding())
    return false;
  // The run must start with a (loose) compressed level.
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  // All remaining levels must be singleton.
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}
1015
RankedTensorType
  // (Signature line elided in this rendering; presumably
  // SparseTensorType::getCOOType(bool ordered) const -- confirm.)
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at beginning (unless this is
  // also the last level, then it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends by a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  // Build the COO encoding reusing this type's maps and widths. (One
  // argument line of this call is elided in this rendering.)
  auto enc = SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}
1036
1037//===----------------------------------------------------------------------===//
1038// Convenience Methods.
1039//===----------------------------------------------------------------------===//
1040
SparseTensorEncodingAttr
  // (Signature line elided in this rendering; presumably the free function
  // mlir::sparse_tensor::getSparseTensorEncoding(Type type) -- confirm.)
  // Ranked tensors carry the encoding directly (null for dense tensors).
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  // Storage specifier types wrap an encoding as well.
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  // Any other type has no sparse tensor encoding.
  return nullptr;
}
1049
                                  MLIRContext *context) {
  // (Signature line elided in this rendering; takes the dimToLvl map and
  // returns its inferred inverse, per the uses below.)
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    // Permutations can be inverted directly.
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    // Block-sparse maps (floordiv/mod pairs) use the dedicated inverse.
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}
1064
                                  MLIRContext *context) {
  // (Signature line elided in this rendering; takes the block-sparse
  // dimToLvl AffineMap, per the uses below.)
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      // Non-blocked dimensions invert to a plain level variable.
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
1117
  // (Signature line elided in this rendering; returns the per-dimension
  // block sizes of a block-sparse dimToLvl map, per the uses below.)
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        // The constant rhs of the mod is the block size for this dimension.
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      // Non-blocked dimensions report block size 0.
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
1134
  // (Signature line elided in this rendering; checks whether the given
  // dimToLvl map encodes block sparsity, per the uses below.)
  if (!dimToLvl)
    return false;
  // Maps each dimension position to its floordiv coefficient (0 for plain
  // dimensions). NOTE(review): "coeffientMap" is a typo for "coefficientMap".
  std::map<unsigned, int64_t> coeffientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        auto [it, inserted] = coeffientMap.try_emplace(pos);
        if (!inserted)
          return false;
        // Record coefficient of the floordiv.
        it->second = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        auto it = coeffientMap.find(pos);
        if (it == coeffientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (conOp.getValue() != it->second)
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect dim to be unset.
      if (!coeffientMap.try_emplace(pos, 0).second)
        return false;
    } else {
      return false;
    }
  }
  return hasBlock;
}
1179
  // (Signature line elided in this rendering; takes an Operation *op, per
  // the uses below.)
  // A value counts as "mapped" when it is a sparse tensor whose
  // dimToLvl/lvlToDim mapping is not the identity.
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  // True if any operand or result carries a non-identity mapping.
  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}
1189
1190Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1191 if (enc) {
1192 assert(enc.isPermutation() && "Non permutation map not supported");
1193 if (const auto dimToLvl = enc.getDimToLvl())
1194 return dimToLvl.getDimPosition(l);
1195 }
1196 return l;
1197}
1198
1199Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1200 if (enc) {
1201 assert(enc.isPermutation() && "Non permutation map not supported");
1202 if (const auto lvlToDim = enc.getLvlToDim())
1203 return lvlToDim.getDimPosition(d);
1204 }
1205 return d;
1206}
1207
/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and stripping
/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  // (The declaration of `lts` is elided in this rendering; presumably
  // SmallVector<LevelType> lts; -- confirm against the original source.)
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // value for different bitwidth, it also avoids casting between index and
      // integer (returned by DimOp)
      0, 0,
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
}
1231
1232StorageSpecifierType
1233StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1234 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1235}
1236
StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {
  // Same as get(), but reports invalid encodings through emitError.
  // (The final argument line of this call is elided in this rendering;
  // presumably getNormalizedEncodingForSpecifier(encoding) -- confirm.)
  return Base::getChecked(emitError, ctx,
}
1244
1245//===----------------------------------------------------------------------===//
1246// SparseTensorDialect Operations.
1247//===----------------------------------------------------------------------===//
1248
1249static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1250 return success(lvl < getSparseTensorType(tensor).getLvlRank());
1251}
1252
1253static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1254 const Type etp = getMemRefType(mem).getElementType();
1255 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1256}
1257
static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    // (One parameter line is elided in this rendering; `md` is the specifier
    // value and `op` is the op under verification, per the uses below.)
  // The value memory size is level-independent, so passing a level there
  // is rejected as redundant.
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  // Slice metadata (offset/stride) only exists on slice tensors.
  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  // Every other query requires an in-bounds level argument.
  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    // Position memory size is meaningless on singleton levels (presumably
    // because singleton levels store no positions -- confirm).
    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}
1288
 switch (kind) {
  // (The function signature and the `case` labels of this switch are elided
  // in this rendering; each return below corresponds to one
  // SparseTensorFieldKind -- confirm against the original source.)
    // Coordinate memrefs use the encoding's coordinate type.
    return stt.getCrdType();
    // Position memrefs use the encoding's position type.
    return stt.getPosType();
    // The value memref uses the tensor's element type.
    return stt.getElementType();
    // The storage specifier field has no element type.
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}
1302
// Shared verifier for AssembleOp/DisassembleOp: checks the value buffer and
// the per-level buffers against the sparse tensor's storage layout.
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now, must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in shape of <? x rank>
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    // (A guard line is elided in this rendering; presumably it skips the
    // storage-specifier field -- confirm against the original source.)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}
1358
1359LogicalResult AssembleOp::verify() {
1360 RankedTensorType valuesTp = getValues().getType();
1361 const auto lvlsTp = getLevels().getTypes();
1362 const auto resTp = getSparseTensorType(getResult());
1363 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1364}
1365
1366LogicalResult DisassembleOp::verify() {
1367 if (getOutValues().getType() != getRetValues().getType())
1368 return emitError("output values and return value type mismatch");
1369
1370 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1371 if (ot.getType() != rt.getType())
1372 return emitError("output levels and return levels type mismatch");
1373
1374 RankedTensorType valuesTp = getRetValues().getType();
1375 const auto lvlsTp = getRetLevels().getTypes();
1376 const auto srcTp = getSparseTensorType(getTensor());
1377 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1378}
1379
1380LogicalResult ConvertOp::verify() {
1381 RankedTensorType tp1 = getSource().getType();
1382 RankedTensorType tp2 = getDest().getType();
1383 if (tp1.getRank() != tp2.getRank())
1384 return emitError("unexpected conversion mismatch in rank");
1385 auto dstEnc =
1386 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1387 if (dstEnc && dstEnc.isSlice())
1388 return emitError("cannot convert to a sparse tensor slice");
1389
1390 auto shape1 = tp1.getShape();
1391 auto shape2 = tp2.getShape();
1392 // Accept size matches between the source and the destination type
1393 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1394 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1395 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1396 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1397 return emitError("unexpected conversion mismatch in dimension ") << d;
1398 return success();
1399}
1400
1401OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1402 if (getType() == getSource().getType())
1403 return getSource();
1404 return {};
1405}
1406
1407bool ConvertOp::needsExtraSort() {
1408 SparseTensorType srcStt = getSparseTensorType(getSource());
1409 SparseTensorType dstStt = getSparseTensorType(getDest());
1410
1411 // We do not need an extra sort when returning unordered sparse tensors or
1412 // dense tensor since dense tensor support random access.
1413 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1414 return false;
1415
1416 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1417 srcStt.hasSameDimToLvl(dstStt)) {
1418 return false;
1419 }
1420
1421 // Source and dest tensors are ordered in different ways. We only do direct
1422 // dense to sparse conversion when the dense input is defined by a sparse
1423 // constant. Note that we can theoretically always directly convert from dense
1424 // inputs by rotating dense loops but it leads to bad cache locality and hurt
1425 // performance.
1426 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1427 if (isa<SparseElementsAttr>(constOp.getValue()))
1428 return false;
1429
1430 return true;
1431}
1432
1433LogicalResult CrdTranslateOp::verify() {
1434 uint64_t inRank = getEncoder().getLvlRank();
1435 uint64_t outRank = getEncoder().getDimRank();
1436
1437 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1438 std::swap(inRank, outRank);
1439
1440 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1441 return emitError("Coordinate rank mismatch with encoding");
1442
1443 return success();
1444}
1445
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  // Identity mapping: output coordinates are the input coordinates.
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  // Permutation mapping: simply permute the input coordinates.
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  // All input coordinates must be produced by one and the same translate op.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  // The producer must translate in the opposite direction, over the same
  // dimToLvl map, and yield exactly as many coordinates as we consume.
  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}
1491
1492void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1493 int64_t index) {
1494 Value val = arith::ConstantIndexOp::create(builder, state.location, index);
1495 return build(builder, state, source, val);
1496}
1497
1498LogicalResult LvlOp::verify() {
1499 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1500 auto stt = getSparseTensorType(getSource());
1501 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1502 return emitError(
1503 "Level index exceeds the rank of the input sparse tensor");
1504 }
1505 return success();
1506}
1507
1508std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1509 return getConstantIntValue(getIndex());
1510}
1511
Speculation::Speculatability LvlOp::getSpeculatability() {
  // A non-constant level index cannot be proven in bounds.
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
  // (The two return statements of this function are elided in this
  // rendering; presumably NotSpeculatable above and Speculatable at the
  // end -- confirm against the original source.)

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
}
1521
1522OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1523 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1524 if (!lvlIndex)
1525 return {};
1526
1527 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1528 auto stt = getSparseTensorType(getSource());
1529 if (lvl >= stt.getLvlRank()) {
1530 // Follows the same convention used by tensor.dim operation. Out of bound
1531 // indices produce undefined behavior but are still valid IR. Don't choke on
1532 // them.
1533 return {};
1534 }
1535
1536 // Helper lambda to build an IndexAttr.
1537 auto getIndexAttr = [this](int64_t lvlSz) {
1538 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1539 };
1540
1541 SmallVector<Size> lvlShape = stt.getLvlShape();
1542 if (ShapedType::isStatic(lvlShape[lvl]))
1543 return getIndexAttr(lvlShape[lvl]);
1544
1545 return {};
1546}
1547
1548void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1549 SparseTensorEncodingAttr dstEnc, Value source) {
1550 auto srcStt = getSparseTensorType(source);
1551 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1552 SmallVector<int64_t> dstDimShape =
1553 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1554 auto dstTp =
1555 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1556 return build(odsBuilder, odsState, dstTp, source);
1557}
1558
1559LogicalResult ReinterpretMapOp::verify() {
1560 auto srcStt = getSparseTensorType(getSource());
1561 auto dstStt = getSparseTensorType(getDest());
1562 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1563 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1564
1565 if (srcLvlTps.size() != dstLvlTps.size())
1566 return emitError("Level rank mismatch between source/dest tensors");
1567
1568 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1569 if (srcLvlTp != dstLvlTp)
1570 return emitError("Level type mismatch between source/dest tensors");
1571
1572 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1573 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1574 return emitError("Crd/Pos width mismatch between source/dest tensors");
1575 }
1576
1577 if (srcStt.getElementType() != dstStt.getElementType())
1578 return emitError("Element type mismatch between source/dest tensors");
1579
1580 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1581 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1582 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1583 if (srcLvlSz != dstLvlSz) {
1584 // Should we allow one side to be dynamic size, e.g., <?x?> should be
1585 // compatible to <3x4>? For now, we require all the level sizes to be
1586 // *exactly* matched for simplicity.
1587 return emitError("Level size mismatch between source/dest tensors");
1588 }
1589 }
1590
1591 return success();
1592}
1593
1594OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1595 if (getSource().getType() == getDest().getType())
1596 return getSource();
1597
1598 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1599 // A -> B, B -> A ==> A
1600 if (def.getSource().getType() == getDest().getType())
1601 return def.getSource();
1602 }
1603 return {};
1604}
1605
// Shared return-type inference for the to_positions/to_coordinates/
// to_coordinates_buffer/to_values ops.
template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
  // (The trailing result parameter line is elided in this rendering;
  // presumably SmallVectorImpl<mlir::Type> &ret) { -- confirm.)
  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  // Select the buffer element type according to the op being inferred.
  Type elemTp = nullptr;
  bool withStride = false;
  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
    elemTp = stt.getPosType();
  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
    elemTp = stt.getCrdType();
    // Coordinates at or past the AoS COO start are exposed as strided views.
    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
    elemTp = stt.getElementType();
  }

  assert(elemTp && "unhandled operation.");
  // Result shape: the batch-level shape extended by a dynamic trailing dim.
  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
  bufShape.push_back(ShapedType::kDynamic);

  auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
                                 stt.getContext(), ShapedType::kDynamic,
                                 {ShapedType::kDynamic})
                           : StridedLayoutAttr();
  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
  return success();
}
1637
1638LogicalResult ToPositionsOp::verify() {
1639 auto stt = getSparseTensorType(getTensor());
1640 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1641 return emitError("requested level is out of bounds");
1642 if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1643 return emitError("unexpected type for positions");
1644 return success();
1645}
1646
LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared sparse-buffer type inference helper.
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}
1654
1655LogicalResult ToCoordinatesOp::verify() {
1656 auto stt = getSparseTensorType(getTensor());
1657 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1658 return emitError("requested level is out of bounds");
1659 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1660 return emitError("unexpected type for coordinates");
1661 return success();
1662}
1663
LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  OpaqueProperties prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared sparse-buffer type inference helper.
  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}
1671
1672LogicalResult ToCoordinatesBufferOp::verify() {
1673 auto stt = getSparseTensorType(getTensor());
1674 if (stt.getAoSCOOStart() >= stt.getLvlRank())
1675 return emitError("expected sparse tensor with a COO region");
1676 return success();
1677}
1678
LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared sparse-buffer type inference helper.
  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
                                                      ret);
}
1686
1687LogicalResult ToValuesOp::verify() {
1688 auto stt = getSparseTensorType(getTensor());
1689 auto mtp = getMemRefType(getResult());
1690 if (stt.getElementType() != mtp.getElementType())
1691 return emitError("unexpected mismatch in element types");
1692 return success();
1693}
1694
LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  // Delegate to the shared sparse-buffer type inference helper.
  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}
1703
1704LogicalResult ToSliceOffsetOp::verify() {
1705 auto rank = getSlice().getType().getRank();
1706 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1707 return emitError("requested dimension out of bound");
1708 return success();
1709}
1710
1711LogicalResult ToSliceStrideOp::verify() {
1712 auto rank = getSlice().getType().getRank();
1713 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1714 return emitError("requested dimension out of bound");
1715 return success();
1716}
1717
LogicalResult GetStorageSpecifierOp::verify() {
  // Delegate to the shared getter/setter verifier.
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}
1722
/// Returns the SetStorageSpecifierOp (if any) that defines the specifier
/// operand of the given get/set op.
template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}
1727
1728OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1729 const StorageSpecifierKind kind = getSpecifierKind();
1730 const auto lvl = getLevel();
1731 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1732 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1733 return op.getValue();
1734 return {};
1735}
1736
LogicalResult SetStorageSpecifierOp::verify() {
  // Delegate to the shared getter/setter verifier.
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}
1741
1742template <class T>
1743static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1744 const char *regionName,
1745 TypeRange inputTypes, Type outputType) {
1746 unsigned numArgs = region.getNumArguments();
1747 unsigned expectedNum = inputTypes.size();
1748 if (numArgs != expectedNum)
1749 return op->emitError() << regionName << " region must have exactly "
1750 << expectedNum << " arguments";
1751
1752 for (unsigned i = 0; i < numArgs; i++) {
1753 Type typ = region.getArgument(i).getType();
1754 if (typ != inputTypes[i])
1755 return op->emitError() << regionName << " region argument " << (i + 1)
1756 << " type mismatch";
1757 }
1758 Block &block = region.front();
1759 if (!block.mightHaveTerminator())
1760 return op->emitError() << regionName
1761 << " region must end with a terminator";
1762
1763 Operation *term = block.getTerminator();
1764 YieldOp yield = dyn_cast<YieldOp>(term);
1765 if (!yield)
1766 return op->emitError() << regionName
1767 << " region must end with sparse_tensor.yield";
1768 if (!yield.hasSingleResult() ||
1769 yield.getSingleResult().getType() != outputType)
1770 return op->emitError() << regionName << " region yield type mismatch";
1771
1772 return success();
1773}
1774
1775LogicalResult BinaryOp::verify() {
1776 NamedAttrList attrs = (*this)->getAttrs();
1777 Type leftType = getX().getType();
1778 Type rightType = getY().getType();
1779 Type outputType = getOutput().getType();
1780 Region &overlap = getOverlapRegion();
1781 Region &left = getLeftRegion();
1782 Region &right = getRightRegion();
1783
1784 // Check correct number of block arguments and return type for each
1785 // non-empty region.
1786 if (!overlap.empty()) {
1787 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1788 TypeRange{leftType, rightType}, outputType)))
1789 return failure();
1790 }
1791 if (!left.empty()) {
1792 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1793 outputType)))
1794 return failure();
1795 } else if (getLeftIdentity()) {
1796 if (leftType != outputType)
1797 return emitError("left=identity requires first argument to have the same "
1798 "type as the output");
1799 }
1800 if (!right.empty()) {
1801 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1802 outputType)))
1803 return failure();
1804 } else if (getRightIdentity()) {
1805 if (rightType != outputType)
1806 return emitError("right=identity requires second argument to have the "
1807 "same type as the output");
1808 }
1809 return success();
1810}
1811
/// Verifies the present/absent regions of a sparse_tensor.unary op. The
/// present region takes the input value; the absent region takes no
/// arguments, and may only yield values that are invariant (constants or
/// values defined outside both the absent block and the op's own block).
LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      // Yielding a block argument of the enclosing block is not invariant.
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      // Non-constant values computed in the absent block or in the op's own
      // block are rejected; values from enclosing scopes are fine.
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}
1845
1846bool ConcatenateOp::needsExtraSort() {
1847 SparseTensorType dstStt = getSparseTensorType(*this);
1848 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1849 return false;
1850
1851 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1852 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1853 });
1854 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1855 // in all input/output buffers, and all input/output buffers have the same
1856 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1857 // CSC matrices along column).
1858 bool directLowerable =
1859 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1860 return !directLowerable;
1861}
1862
1863LogicalResult ConcatenateOp::verify() {
1864 const auto dstTp = getSparseTensorType(*this);
1865 const Dimension concatDim = getDimension();
1866 const Dimension dimRank = dstTp.getDimRank();
1867
1868 if (getInputs().size() <= 1)
1869 return emitError("Need at least two tensors to concatenate.");
1870
1871 if (concatDim >= dimRank)
1872 return emitError(llvm::formatv(
1873 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1874 concatDim, dimRank));
1875
1876 for (const auto &it : llvm::enumerate(getInputs())) {
1877 const auto i = it.index();
1878 const auto srcTp = getSparseTensorType(it.value());
1879 if (srcTp.hasDynamicDimShape())
1880 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1881 const Dimension srcDimRank = srcTp.getDimRank();
1882 if (srcDimRank != dimRank)
1883 return emitError(
1884 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1885 "from the output tensor (rank={2}).",
1886 i, srcDimRank, dimRank));
1887 }
1888
1889 for (Dimension d = 0; d < dimRank; d++) {
1890 const Size dstSh = dstTp.getDimShape()[d];
1891 if (d == concatDim) {
1892 if (ShapedType::isStatic(dstSh)) {
1893 // If we reach here, then all inputs have static shapes. So we
1894 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1895 // to avoid redundant assertions in the loop.
1896 Size sumSz = 0;
1897 for (const auto src : getInputs())
1898 sumSz += getSparseTensorType(src).getDimShape()[d];
1899 // If all dimension are statically known, the sum of all the input
1900 // dimensions should be equal to the output dimension.
1901 if (sumSz != dstSh)
1902 return emitError(
1903 "The concatenation dimension of the output tensor should be the "
1904 "sum of all the concatenation dimensions of the input tensors.");
1905 }
1906 } else {
1907 Size prev = dstSh;
1908 for (const auto src : getInputs()) {
1909 const auto sh = getSparseTensorType(src).getDimShape()[d];
1910 if (ShapedType::isStatic(prev) && sh != prev)
1911 return emitError("All dimensions (expect for the concatenating one) "
1912 "should be equal.");
1913 prev = sh;
1914 }
1915 }
1916 }
1917
1918 return success();
1919}
1920
/// Convenience builder: builds a push_back op without the optional
/// repetition count operand `n`.
void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}
1925
1926LogicalResult PushBackOp::verify() {
1927 if (Value n = getN()) {
1928 std::optional<int64_t> nValue = getConstantIntValue(n);
1929 if (nValue && nValue.value() < 1)
1930 return emitOpError("n must be not less than 1");
1931 }
1932 return success();
1933}
1934
1935LogicalResult CompressOp::verify() {
1936 const auto stt = getSparseTensorType(getTensor());
1937 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1938 return emitOpError("incorrect number of coordinates");
1939 return success();
1940}
1941
/// Builds a ForeachOp and, when `bodyBuilder` is provided, populates its body
/// block. The block arguments are ordered as: `dimRank` index coordinates,
/// then the element value, then the reduction (init) arguments.
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  // All block arguments are attributed to the tensor's location.
  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  // Hand the callback the coordinates, the value, and the reduction args.
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
1972
/// Verifies a sparse_tensor.foreach: the optional traversal order must cover
/// the tensor's level rank; the body block must carry dimRank coordinates,
/// one element value, and the init args; and init args, yielded values, and
/// results must agree in number and type.
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  // Block args = coordinates + element value + reduction variables.
  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  // The leading dimRank block arguments are coordinates and must be indices.
  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  // The value argument must match the tensor's element type.
  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}
2011
2012OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2013 if (getSparseTensorEncoding(getInputCoo().getType()) ==
2014 getSparseTensorEncoding(getResultCoo().getType()))
2015 return getInputCoo();
2016
2017 return {};
2018}
2019
2020LogicalResult ReorderCOOOp::verify() {
2021 SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2022 SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2023
2024 if (!srcStt.isCOOType() || !dstStt.isCOOType())
2025 return emitError("Expected COO sparse tensors only");
2026
2027 if (!srcStt.hasSameDimToLvl(dstStt))
2028 return emitError("Unmatched dim2lvl map between input and result COO");
2029
2030 if (srcStt.getPosType() != dstStt.getPosType() ||
2031 srcStt.getCrdType() != dstStt.getCrdType() ||
2032 srcStt.getElementType() != dstStt.getElementType())
2033 return emitError("Unmatched storage format between input and result COO");
2034
2035 return success();
2036}
2037
2038LogicalResult ReduceOp::verify() {
2039 Type inputType = getX().getType();
2040 Region &formula = getRegion();
2041 return verifyNumBlockArgs(this, formula, "reduce",
2042 TypeRange{inputType, inputType}, inputType);
2043}
2044
2045LogicalResult SelectOp::verify() {
2046 Builder b(getContext());
2047 Type inputType = getX().getType();
2048 Type boolType = b.getI1Type();
2049 Region &formula = getRegion();
2050 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2051 boolType);
2052}
2053
2054LogicalResult SortOp::verify() {
2055 AffineMap xPerm = getPermMap();
2056 uint64_t nx = xPerm.getNumDims();
2057 if (nx < 1)
2058 return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2059
2060 if (!xPerm.isPermutation())
2061 return emitError(
2062 llvm::formatv("Expected a permutation map, got {0}", xPerm));
2063
2064 // We can't check the size of the buffers when n or buffer dimensions aren't
2065 // compile-time constants.
2066 std::optional<int64_t> cn = getConstantIntValue(getN());
2067 if (!cn)
2068 return success();
2069
2070 // Verify dimensions.
2071 const auto checkDim = [&](Value v, Size minSize,
2072 const char *message) -> LogicalResult {
2073 const Size sh = getMemRefType(v).getShape()[0];
2074 if (ShapedType::isStatic(sh) && sh < minSize)
2075 return emitError(
2076 llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2077 return success();
2078 };
2079 uint64_t n = cn.value();
2080 uint64_t ny = 0;
2081 if (auto nyAttr = getNyAttr())
2082 ny = nyAttr.getInt();
2083 if (failed(checkDim(getXy(), n * (nx + ny),
2084 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2085 return failure();
2086 for (Value opnd : getYs())
2087 if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2088 return failure();
2089
2090 return success();
2091}
2092
2093//===----------------------------------------------------------------------===//
2094// Sparse Tensor Iteration Operations.
2095//===----------------------------------------------------------------------===//
2096
/// Returns the iteration-space type paired with this iterator type (same
/// encoding and [lo, hi) level range).
IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}

/// Returns the iterator type paired with this iteration-space type (same
/// encoding and [lo, hi) level range).
IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}
2105
2106/// Parses a level range in the form "$lo `to` $hi"
2107/// or simply "$lo" if $hi - $lo = 1
2108static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2109 Level &lvlHi) {
2110 if (parser.parseInteger(lvlLo))
2111 return failure();
2112
2113 if (succeeded(parser.parseOptionalKeyword("to"))) {
2114 if (parser.parseInteger(lvlHi))
2115 return failure();
2116 } else {
2117 lvlHi = lvlLo + 1;
2118 }
2119
2120 if (lvlHi <= lvlLo)
2121 return parser.emitError(parser.getNameLoc(),
2122 "expect larger level upper bound than lower bound");
2123
2124 return success();
2125}
2126
2127/// Parses a level range in the form "$lo `to` $hi"
2128/// or simply "$lo" if $hi - $lo = 1
2129static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2130 IntegerAttr &lvlHiAttr) {
2131 Level lvlLo, lvlHi;
2132 if (parseLevelRange(parser, lvlLo, lvlHi))
2133 return failure();
2134
2135 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2136 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2137 return success();
2138}
2139
2140/// Prints a level range in the form "$lo `to` $hi"
2141/// or simply "$lo" if $hi - $lo = 1
2142static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2143
2144 if (lo + 1 == hi)
2145 p << lo;
2146 else
2147 p << lo << " to " << hi;
2148}
2149
2150/// Prints a level range in the form "$lo `to` $hi"
2151/// or simply "$lo" if $hi - $lo = 1
2152static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2153 IntegerAttr lvlHi) {
2154 unsigned lo = lvlLo.getValue().getZExtValue();
2155 unsigned hi = lvlHi.getValue().getZExtValue();
2156 printLevelRange(p, lo, hi);
2157}
2158
/// Parses a list of `optional` defined list in the form of
/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
/// corresponding value is not defined (e.g., to represent an undefined
/// coordinate in the sparse iteration space).
static ParseResult parseOptionalDefinedList(
    OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
    SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
    unsigned maxCnt = std::numeric_limits<unsigned>::max(),
    OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
  unsigned cnt = 0;
  ParseResult crdList =
      parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
        // A "_" keyword leaves this slot undefined; otherwise parse an SSA
        // argument and record its position in the defined set.
        if (parser.parseOptionalKeyword("_")) {
          if (parser.parseArgument(definedArgs.emplace_back()))
            return failure();
          definedSet.set(cnt);
        }
        cnt += 1;
        return success();
      });

  if (cnt > maxCnt)
    return parser.emitError(parser.getNameLoc(),
                            "parsed more value than expected.");

  if (failed(crdList)) {
    return parser.emitError(
        parser.getNameLoc(),
        "expecting SSA value or \"_\" for level coordinates");
  }
  // Exactly one parsed argument per set bit.
  assert(definedArgs.size() == definedSet.count());
  return success();
}
2192
2193static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2194 Block::BlockArgListType blocksArgs,
2195 I64BitSet definedSet) {
2196 if (definedSet.empty())
2197 return;
2198
2199 for (unsigned i = 0; i < size; i++) {
2200 if (definedSet[i]) {
2201 p << blocksArgs.front();
2202 blocksArgs = blocksArgs.drop_front();
2203 } else {
2204 p << "_";
2205 }
2206 if (i != size - 1)
2207 p << ", ";
2208 }
2209 assert(blocksArgs.empty());
2210}
2211
/// Parses the optional "at(%crd0, _, ...)" used-coordinate clause, forces
/// every parsed coordinate to IndexType, and records the used levels in the
/// "crdUsedLvls" attribute on `state`.
static ParseResult
parseUsedCoordList(OpAsmParser &parser, OperationState &state,
                   SmallVectorImpl<OpAsmParser::Argument> &coords) {
  // Parse "at(%crd0, _, ...)"
  I64BitSet crdUsedLvlSet;
  if (succeeded(parser.parseOptionalKeyword("at")) &&
      failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
    return failure();

  // Always use IndexType for the coordinate.
  for (auto &coord : coords)
    coord.type = parser.getBuilder().getIndexType();

  // Set the CrdUsedLvl bitset.
  state.addAttribute("crdUsedLvls",
                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
  return success();
}
2230
/// Parses the shared syntax of a sparse iterate loop:
///   "%iters, ... in %spaces, ... at(...) iter_args(%arg = %init, ...)
///    : !sparse_tensor.iter_space, ... (-> ret, ...)"
/// resolving all operands into `state` and collecting region arguments in
/// ([iter args], [used coords]) order.
static ParseResult
parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
                       SmallVectorImpl<OpAsmParser::Argument> &iterators,
                       SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;

  // Parse "%iters, ... in %spaces, ..."
  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
      parser.parseOperandList(spaces))
    return failure();

  if (iterators.size() != spaces.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of sparse iterators and sparse spaces");

  // Parse the optional "at(...)" used-coordinate clause.
  SmallVector<OpAsmParser::Argument> coords;
  if (failed(parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();

  blockArgs.append(coords);

  SmallVector<Type> iterSpaceTps;
  // parse ": sparse_tensor.iter_space -> ret"
  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
    return failure();
  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  // Each iterator's type is derived from its space's type.
  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
    if (!spaceTp)
      return parser.emitError(parser.getNameLoc(),
                              "expected sparse_tensor.iter_space type for "
                              "iteration space operands");
    it.type = spaceTp.getIteratorType();
  }

  if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
      return failure();

  // Resolves input operands.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             state.operands))
    return failure();

  if (hasIterArgs) {
    // Strip off leading args that used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
    }
  }
  return success();
}
2305
/// Parses the operand/argument syntax of a sparse co-iterating loop:
///   "(%spaces, ...) at(...) iter_args(%arg = %init, ...)
///    : (!sparse_tensor.iter_space, ...) (-> ret, ...)"
/// resolving the iteration spaces into `spacesVals` and collecting region
/// arguments in ([iter args], [used coords]) order.
static ParseResult
parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
                         SmallVectorImpl<Value> &spacesVals,
                         SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {

  // Parse "(%spaces, ...)"
  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse the optional "at(...)" used-coordinate clause.
  SmallVector<OpAsmParser::Argument> coords;
  if (failed(parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();
  blockArgs.append(coords);

  SmallVector<Type> iterSpaceTps;
  // parse ": (sparse_tensor.iter_space, ...) -> ret"
  if (parser.parseColon() || parser.parseLParen() ||
      parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
    return failure();

  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
      return failure();

  // Resolves input sparse iteration spaces.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             spacesVals))
    return failure();
  state.operands.append(spacesVals);

  if (hasIterArgs) {
    // Strip off trailing args that used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
    }
  }
  return success();
}
2367
/// Infers the single iteration-space result type from the tensor operand's
/// sparse encoding and the lo/hi level attributes.
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {

  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
                                   adaptor.getHiLvl()));
  return success();
}
2379
2380LogicalResult ExtractIterSpaceOp::verify() {
2381 if (getLoLvl() >= getHiLvl())
2382 return emitOpError("expected smaller level low than level high");
2383
2384 TypedValue<IteratorType> pIter = getParentIter();
2385 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2386 return emitOpError(
2387 "parent iterator should be specified iff level lower bound equals 0");
2388 }
2389
2390 if (pIter) {
2391 IterSpaceType spaceTp = getExtractedSpace().getType();
2392 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2393 return emitOpError(
2394 "mismatch in parent iterator encoding and iteration space encoding.");
2395
2396 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2397 return emitOpError("parent iterator should be used to extract an "
2398 "iteration space from a consecutive level.");
2399 }
2400
2401 return success();
2402}
2403
2404LogicalResult ExtractValOp::verify() {
2405 auto stt = getSparseTensorType(getTensor());
2406 auto itTp = getIterator().getType();
2407
2408 if (stt.getEncoding() != itTp.getEncoding())
2409 return emitOpError("mismatch in tensor encoding and iterator encoding.");
2410
2411 if (stt.getLvlRank() != itTp.getHiLvl())
2412 return emitOpError("must use last-level iterator to extract values. ");
2413
2414 return success();
2415}
2416
/// Canonicalization pattern: drops coordinate block arguments of an IterateOp
/// that have no users and shrinks the crdUsedLvls bitset accordingly.
struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IterateOp iterateOp,
                                PatternRewriter &rewriter) const override {
    I64BitSet newUsedLvls(0);
    llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
    for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
      // A coordinate with no users can be erased; otherwise keep its bit.
      if (auto crd = iterateOp.getLvlCrd(i)) {
        if (crd->getUsers().empty())
          toRemove.set(crd->getArgNumber());
        else
          newUsedLvls.set(i);
      }
    }

    // All coordinates are used.
    if (toRemove.none())
      return failure();

    rewriter.startOpModification(iterateOp);
    iterateOp.setCrdUsedLvls(newUsedLvls);
    iterateOp.getBody()->eraseArguments(toRemove);
    rewriter.finalizeOpModification(iterateOp);
    return success();
  }
};
2444
/// Registers IterateOp canonicalization patterns.
void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                            mlir::MLIRContext *context) {
  results.add<RemoveUnusedLvlCrds>(context);
}
2449
2450void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2451 Value iterSpace, ValueRange initArgs) {
2452 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2453 // All ones.
2454 I64BitSet set((1 << rank) - 1);
2455 return build(builder, odsState, iterSpace, initArgs, set);
2456}
2457
/// Builds an IterateOp with an explicit used-coordinate set, creating its
/// body block with arguments ordered as: [iter args], [used coordinates],
/// [sparse iterator].
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(iterSpace);
  odsState.addOperands(initArgs);
  // Encode the used-level bitset as an i64 property.
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  odsState.addTypes(initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(bodyRegion);

  // Starts with a list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Follows by a list of used coordinates.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Ends with sparse iterator
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);
}
2484
2485ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2486 OpAsmParser::Argument iterator;
2487 OpAsmParser::UnresolvedOperand iterSpace;
2488
2489 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2490 if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2491 return failure();
2492 if (iters.size() != 1)
2493 return parser.emitError(parser.getNameLoc(),
2494 "expected only one iterator/iteration space");
2495
2496 iterArgs.append(iters);
2497 Region *body = result.addRegion();
2498 if (parser.parseRegion(*body, iterArgs))
2499 return failure();
2500
2501 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2502
2503 // Parse the optional attribute list.
2504 if (parser.parseOptionalAttrDict(result.attributes))
2505 return failure();
2506
2507 return success();
2508}
2509
/// Prints the initialization list in the form of
/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
/// where 'inner' values are assumed to be region arguments and 'outer' values
/// are regular SSA values.
static void printInitializationList(OpAsmPrinter &p,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  // Prints nothing at all (not even the prefix) for an empty list.
  if (initializers.empty())
    return;

  p << prefix << '(';
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
  });
  p << ")";
}
2529
2530template <typename SparseLoopOp>
2531static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2532 if (op.getInitArgs().size() != op.getNumResults()) {
2533 return op.emitOpError(
2534 "mismatch in number of loop-carried values and defined values");
2535 }
2536 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2537 return op.emitOpError("required out-of-bound coordinates");
2538
2539 return success();
2540}
2541
// Both loop ops share the structural checks in verifySparseLoopOp.
LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2544
/// Prints "iterate %iter in %space at(...) iter_args(...) : type (-> rets)"
/// followed by the body region.
void IterateOp::print(OpAsmPrinter &p) {
  p << " " << getIterator() << " in " << getIterSpace();
  // The at(...) clause only appears when some coordinate is used.
  if (!getCrdUsedLvls().empty()) {
    p << " at(";
    printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
    p << ")";
  }
  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");

  p << " : " << getIterSpace().getType() << " ";
  if (!getInitArgs().empty())
    p.printArrowTypeList(getInitArgs().getTypes());

  p << " ";
  // Terminators are implicit (and thus elided) when there are no results.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
}
2562
/// Verifies the region-dependent invariants: iterator/space type agreement,
/// and that init operands, region iter args, yielded values, and op results
/// all agree in number and per-position type.
LogicalResult IterateOp::verifyRegions() {
  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
    return emitOpError("mismatch in iterator and iteration space type");
  if (getNumRegionIterArgs() != getNumResults())
    return emitOpError(
        "mismatch in number of basic block args and defined values");

  auto initArgs = getInitArgs();
  auto iterArgs = getRegionIterArgs();
  auto yieldVals = getYieldedValues();
  auto opResults = getResults();
  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
                        opResults.size()})) {
    return emitOpError() << "number mismatch between iter args and results.";
  }

  // Each loop-carried slot must use one consistent type throughout.
  for (auto [i, init, iter, yield, ret] :
       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
    if (init.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
    if (iter.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";
    if (yield.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th yield value and defined value";
  }

  return success();
}
2594
/// OpInterfaces' methods implemented by IterateOp.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }

// The loop-carried init operands, as mutable operand slots.
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
  return getInitArgsMutable();
}

// The leading region arguments are the loop-carried values (used coordinates
// and the sparse iterator come after them).
Block::BlockArgListType IterateOp::getRegionIterArgs() {
  return getRegion().getArguments().take_front(getNumRegionIterArgs());
}

// Values yielded by the body's sparse_tensor.yield terminator.
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
  return cast<sparse_tensor::YieldOp>(
             getRegion().getBlocks().front().getTerminator())
      .getResultsMutable();
}

std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }

// Operands forwarded when control first enters the body region.
OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  return getInitArgs();
}

void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself.
  regions.push_back(RegionSuccessor(&getRegion()));
  // It is possible for loop not to enter the body.
  regions.push_back(RegionSuccessor::parent());
}
2626
2627ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) {
2628 return successor.isParent() ? ValueRange(getResults())
2629 : ValueRange(getRegionIterArgs());
2630}
2631
2632void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2633 ValueRange iterSpaces, ValueRange initArgs,
2634 unsigned numCases) {
2635 unsigned rank =
2636 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2637 // All ones.
2638 I64BitSet set((1 << rank) - 1);
2639 // Generates all-zero case bits (they only serve as placeholders), which are
2640 // supposed to be overriden later. We need to preallocate all the regions as
2641 // mlir::Region cannot be dynamically added later after the operation is
2642 // created.
2643 SmallVector<int64_t> caseBits(numCases, 0);
2644 ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2645 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2646 initArgs, set, cases,
2647 /*caseRegionsCount=*/numCases);
2648}
2649
/// Parses "sparse_tensor.coiterate (%spaces, ...) ..." including one region
/// per "case" clause, resolving each case's defined iterators against the
/// corresponding iteration space types.
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {

  SmallVector<Value> spaces;
  // The block argument list of each regions, it is arranged in the order of
  // ([used coordinate list], [loop iterations args], [sparse iterator list]).
  SmallVector<OpAsmParser::Argument> blockArgs;
  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
    return failure();

  // Record the operand split between spaces and init args.
  result.addAttribute("operandSegmentSizes",
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(spaces.size()),
                           static_cast<int32_t>(result.types.size())}));

  SmallVector<Attribute> cases;
  while (succeeded(parser.parseOptionalKeyword("case"))) {
    // Parse one region per case.
    I64BitSet definedItSet;
    SmallVector<OpAsmParser::Argument> definedIts;
    if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
                                 spaces.size(), OpAsmParser::Delimiter::None))
      return failure();

    cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));

    for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
      // Resolve the iterator type based on the iteration space type.
      auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
      definedIts[i].type = spaceTp.getIteratorType();
    }
    // Shared block args (coords + iter args) precede the case's iterators.
    definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, definedIts))
      return failure();

    CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  }

  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2696
2697void CoIterateOp::print(OpAsmPrinter &p) {
2698 p << " (";
2699 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2700 p << ")";
2701
2702 if (!getCrdUsedLvls().empty()) {
2703 p << " at(";
2704 printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2705 p << ")";
2706 }
2707
2708 printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2709
2710 p << " : (" << getIterSpaces().getTypes() << ")";
2711 if (!getInitArgs().empty())
2712 p.printArrowTypeList(getInitArgs().getTypes());
2713
2714 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2715 p.printNewline();
2716 p << "case ";
2717 printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2718 getRegionDefinedSpace(idx));
2719 p << " ";
2720 p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2721 /*printBlockTerminators=*/!getInitArgs().empty());
2722 }
2723}
2724
2725ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2726 return cast<sparse_tensor::YieldOp>(
2727 getRegion(regionIdx).getBlocks().front().getTerminator())
2728 .getResults();
2729}
2730
2731LogicalResult CoIterateOp::verifyRegions() {
2732 for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2733 if (getNumRegionIterArgs() != getNumResults())
2734 return emitOpError(
2735 "mismatch in number of basic block args and defined values");
2736
2737 auto initArgs = getInitArgs();
2738 auto iterArgs = getRegionIterArgs(r);
2739 auto yieldVals = getYieldedValues(r);
2740 auto opResults = getResults();
2741 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2742 opResults.size()})) {
2743 return emitOpError()
2744 << "number mismatch between iter args and results on " << r
2745 << "th region";
2746 }
2747
2748 for (auto [i, init, iter, yield, ret] :
2749 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2750 if (init.getType() != ret.getType())
2751 return emitOpError()
2752 << "types mismatch between " << i
2753 << "th iter operand and defined value on " << r << "th region";
2754 if (iter.getType() != ret.getType())
2755 return emitOpError() << "types mismatch between " << i
2756 << "th iter region arg and defined value on " << r
2757 << "th region";
2758 if (yield.getType() != ret.getType())
2759 return emitOpError()
2760 << "types mismatch between " << i
2761 << "th yield value and defined value on " << r << "th region";
2762 }
2763 }
2764
2765 auto cases = getRegionDefinedSpaces();
2766 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2767 if (set.size() != getNumRegions())
2768 return emitOpError("contains duplicated cases.");
2769
2770 return success();
2771}
2772
2773SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2774 SmallVector<Region *> ret;
2775 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2776 for (Region &r : getCaseRegions())
2777 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2778 ret.push_back(&r);
2779
2780 return ret;
2781}
2782
2783//===----------------------------------------------------------------------===//
2784// Sparse Tensor Dialect Setups.
2785//===----------------------------------------------------------------------===//
2786
2787/// Materialize a single constant operation from a given attribute value with
2788/// the desired resultant type.
2789Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2790 Attribute value, Type type,
2791 Location loc) {
2792 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2793 return op;
2794 return nullptr;
2795}
2796
/// Registers all ODS-generated attributes, types, and operations of the
/// sparse tensor dialect, and declares the bufferization interface
/// implementations that are attached lazily elsewhere.
void SparseTensorDialect::initialize() {
  // Attribute definitions generated from SparseTensorAttrDefs.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  // Type definitions generated from SparseTensorTypes.td.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  // Operation definitions generated from SparseTensorOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Promise (rather than eagerly register) the BufferizableOpInterface
  // implementations; they are supplied by the bufferization library.
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2815
2816#define GET_OP_CLASSES
2817#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2818
2819#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:59
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDiml mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isUniqueLT(LevelType lt)
Definition Enums.h:428
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Definition Enums.h:431
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition Enums.h:402
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
Definition Enums.h:432
bool isOrderedLT(LevelType lt)
Definition Enums.h:425
std::string toMLIRString(LevelType lt)
Definition Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
Definition Enums.h:421
static llvm::hash_code hash_value(LevelType lt)
uint64_t getN(LevelType lt)
Definition Enums.h:442
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
Definition Enums.h:413
uint64_t getM(LevelType lt)
Definition Enums.h:443
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===----------------------------------------------------------------------===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
Definition Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Definition Enums.h:424
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition Enums.h:299