MLIR 22.0.0git
SparseTensorDialect.cpp
Go to the documentation of this file.
1//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
17
22#include "mlir/IR/Builders.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32
// Forward declarations: the following custom print/parse methods are
// referenced by the code generated from SparseTensorTypes.td.
35static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
40
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
43
44using namespace mlir;
45using namespace mlir::sparse_tensor;
46
47// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
48// well.
49namespace mlir::sparse_tensor {
50llvm::hash_code hash_value(LevelType lt) {
51 return llvm::hash_value(static_cast<uint64_t>(lt));
52}
53} // namespace mlir::sparse_tensor
54
55//===----------------------------------------------------------------------===//
56// Local Convenience Methods.
57//===----------------------------------------------------------------------===//
58
/// Returns true when `bitWidth` is an admissible position/coordinate width:
/// either 0 (which selects the `index` type) or one of 8, 16, 32, 64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
71
73getSparseFieldShape(const SparseTensorEncodingAttr enc,
74 std::optional<ArrayRef<int64_t>> dimShape) {
75 assert(enc);
76 // With only encoding, we can not determine the static shape for leading
77 // batch levels, we therefore return a dynamic shape memref instead.
78 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
79 if (dimShape.has_value()) {
80 // If the actual tensor shape is provided, we can then refine the leading
81 // batch dimension.
82 SmallVector<int64_t> lvlShape =
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
86 }
87 // Another dynamic dimension to store the sparse level.
88 memrefShape.push_back(ShapedType::kDynamic);
89 return memrefShape;
90}
91
92//===----------------------------------------------------------------------===//
93// SparseTensorDialect StorageLayout.
94//===----------------------------------------------------------------------===//
95
96static constexpr Level kInvalidLevel = -1u;
97static constexpr Level kInvalidFieldIndex = -1u;
98static constexpr FieldIndex kDataFieldStartingIdx = 0;
99
102 LevelType)>
103 callback) const {
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
106 SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
108
109 ArrayRef cooSegsRef = cooSegs;
110 // Per-level storage.
111 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
112 const auto lt = lvlTypes[l];
113 if (isWithPosLT(lt)) {
114 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
115 return;
116 }
117 if (isWithCrdLT(lt)) {
118 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
119 return;
120 }
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
123 // AoS COO, all singletons are fused into one memrefs. Skips the entire
124 // COO segement.
125 l = cooSegsRef.front().lvlRange.second;
126 } else {
127 // SoA COO, each singleton level has one memref.
128 l++;
129 }
130 // Expire handled COO segment.
131 cooSegsRef = cooSegsRef.drop_front();
132 } else {
133 // Non COO levels.
134 l++;
135 }
136 }
137 // The values array.
138 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
140 return;
141 // Put metadata at the end.
142 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
144 return;
145}
146
150 LevelType)>
151 callback) {
152 assert(stt.hasEncoding());
153
154 SmallVector<int64_t> memrefShape =
156
157 const Type specType = StorageSpecifierType::get(stt.getEncoding());
158 // memref<[batch] x ? x pos> positions
159 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
160 // memref<[batch] x ? x crd> coordinates
161 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
162 // memref<[batch] x ? x eltType> values
163 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
164
165 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
166 callback](FieldIndex fieldIdx,
167 SparseTensorFieldKind fieldKind,
168 Level lvl, LevelType lt) -> bool {
169 switch (fieldKind) {
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
178 };
179 llvm_unreachable("unrecognized field kind");
180 });
181}
182
184 unsigned numFields = 0;
186 LevelType) -> bool {
187 numFields++;
188 return true;
189 });
190 return numFields;
191}
192
194 unsigned numFields = 0; // one value memref
196 LevelType) -> bool {
197 if (fidx >= kDataFieldStartingIdx)
198 numFields++;
199 return true;
200 });
201 numFields -= 1; // the last field is StorageSpecifier
202 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
203 return numFields;
204}
205
206std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl) const {
210 unsigned stride = 1;
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
216 lvl = cooStart;
217 stride = lvlRank - cooStart;
218 }
219 }
220 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
221 SparseTensorFieldKind fKind, Level fLvl,
222 LevelType lt) -> bool {
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
224 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
225 fieldIdx = fIdx;
226 // Returns false to break the iteration.
227 return false;
228 }
229 return true;
230 });
231 assert(fieldIdx != kInvalidFieldIndex);
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
233}
234
235//===----------------------------------------------------------------------===//
236// SparseTensorDialect Attribute Methods.
237//===----------------------------------------------------------------------===//
238
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(static_cast<uint64_t>(v));
242}
243
244std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
245 return getStatic(getOffset());
246}
247
248std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
249 return getStatic(getStride());
250}
251
252std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
253 return getStatic(getSize());
254}
255
256bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
257 return isDynamic(getOffset()) && isDynamic(getStride()) &&
258 isDynamic(getSize());
259}
260
261std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ? "?" : std::to_string(v);
263}
264
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
266 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
267 os << '(';
268 os << getStaticString(getOffset());
269 os << ", ";
270 os << getStaticString(getSize());
271 os << ", ";
272 os << getStaticString(getStride());
273 os << ')';
274}
275
276void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
277 print(printer.getStream());
278}
279
281 AsmParser &parser) {
282 auto parseResult = parser.parseOptionalInteger(result);
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() && result < 0) {
285 parser.emitError(
286 parser.getCurrentLocation(),
287 "expect positive value or ? for slice offset/size/stride");
288 return failure();
289 }
290 return parseResult.value();
291 }
292
293 // Else, and '?' which represented dynamic slice
294 result = SparseTensorDimSliceAttr::kDynamic;
295 return parser.parseQuestion();
296}
297
298Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
300
301 if (failed(parser.parseLParen()) ||
302 failed(parseOptionalStaticSlice(offset, parser)) ||
303 failed(parser.parseComma()) ||
304 failed(parseOptionalStaticSlice(size, parser)) ||
305 failed(parser.parseComma()) ||
306 failed(parseOptionalStaticSlice(stride, parser)) ||
307 failed(parser.parseRParen()))
308 return {};
309
310 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
311 offset, size, stride);
312}
313
314LogicalResult
315SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(offset) && offset < 0)
318 return emitError() << "expect non-negative value or ? for slice offset";
319 if (!isDynamic(size) && size <= 0)
320 return emitError() << "expect positive value or ? for slice size";
321 if (!isDynamic(stride) && stride <= 0)
322 return emitError() << "expect positive value or ? for slice stride";
323 return success();
324}
325
326SparseTensorEncodingAttr
327SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
328 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
329 return SparseTensorEncodingAttr::get(
330 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
331 getCrdWidth(), getExplicitVal(), getImplicitVal());
332}
333
334SparseTensorEncodingAttr
335SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
336 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
337}
338
339SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
340 return withDimToLvl(AffineMap());
341}
342
343SparseTensorEncodingAttr
344SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
345 unsigned crdWidth) const {
346 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
347 return SparseTensorEncodingAttr::get(
348 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
349 crdWidth, getExplicitVal(), getImplicitVal());
350}
351
352SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
353 return withBitWidths(0, 0);
354}
355
356SparseTensorEncodingAttr
357SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
358 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
359 return SparseTensorEncodingAttr::get(
360 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
361 getCrdWidth(), explicitVal, getImplicitVal());
362}
363
364SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
365 return withExplicitVal(Attribute());
366}
367
368SparseTensorEncodingAttr
369SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
370 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
371 return SparseTensorEncodingAttr::get(
372 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
373 getCrdWidth(), getExplicitVal(), implicitVal);
374}
375
376SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
377 return withImplicitVal(Attribute());
378}
379
380SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
381 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
382 return SparseTensorEncodingAttr::get(
383 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
384 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
385}
386
387SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
388 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
389}
390
391uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
392 ArrayRef<LevelType> lvlTypes = getLvlTypes();
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
395}
396
397bool SparseTensorEncodingAttr::isAllDense() const {
398 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
399}
400
401bool SparseTensorEncodingAttr::isAllOrdered() const {
402 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
403}
404
405Type SparseTensorEncodingAttr::getCrdElemType() const {
406 if (!getImpl())
407 return nullptr;
408 if (getCrdWidth())
409 return IntegerType::get(getContext(), getCrdWidth());
410 return IndexType::get(getContext());
411}
412
413Type SparseTensorEncodingAttr::getPosElemType() const {
414 if (!getImpl())
415 return nullptr;
416 if (getPosWidth())
417 return IntegerType::get(getContext(), getPosWidth());
418 return IndexType::get(getContext());
419}
420
421MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
422 std::optional<ArrayRef<int64_t>> dimShape) const {
423 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
424 return MemRefType::get(shape, getCrdElemType());
425}
426
427MemRefType SparseTensorEncodingAttr::getPosMemRefType(
428 std::optional<ArrayRef<int64_t>> dimShape) const {
429 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
430 return MemRefType::get(shape, getPosElemType());
431}
432
433bool SparseTensorEncodingAttr::isIdentity() const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
435}
436
437bool SparseTensorEncodingAttr::isPermutation() const {
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
439}
440
441Dimension SparseTensorEncodingAttr::getDimRank() const {
442 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
445}
446
447Level SparseTensorEncodingAttr::getLvlRank() const {
448 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
450}
451
452LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
453 if (!getImpl())
454 return LevelFormat::Batch;
455 assert(l < getLvlRank() && "Level is out of bounds");
456 return getLvlTypes()[l];
457}
458
459bool SparseTensorEncodingAttr::isSlice() const {
460 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
462}
463
464SparseTensorDimSliceAttr
465SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
466 assert(isSlice() && "Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() && "Dimension is out of bounds");
469 return dimSlices[dim];
470}
471
472std::optional<uint64_t>
473SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
474 return getDimSlice(dim).getStaticOffset();
475}
476
477std::optional<uint64_t>
478SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
479 return getDimSlice(dim).getStaticStride();
480}
481
482std::optional<uint64_t>
483SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
484 return getStaticDimSliceOffset(toDim(*this, lvl));
485}
486
487std::optional<uint64_t>
488SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
489 return getStaticDimSliceStride(toDim(*this, lvl));
490}
491
492SmallVector<int64_t>
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir) const {
495 if (isIdentity())
496 return SmallVector<int64_t>(srcShape);
497
498 SmallVector<int64_t> ret;
499 unsigned rank =
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
501 ret.reserve(rank);
502
503 if (isPermutation()) {
504 for (unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
506 : toLvl(*this, r);
507 ret.push_back(srcShape[trans]);
508 }
509 return ret;
510 }
511
512 // Handle non-permutation maps.
513 AffineMap transMap =
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
515
516 SmallVector<AffineExpr> dimRep;
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (ShapedType::isStatic(sz)) {
520 // Push back the max coordinate for the given dimension/level size.
521 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
522 } else {
523 // A dynamic size, use a AffineDimExpr to symbolize the value.
524 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
525 }
526 };
527
528 for (AffineExpr exp : transMap.getResults()) {
529 // Do constant propagation on the affine map.
530 AffineExpr evalExp =
531 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
532 // use llvm namespace here to avoid ambiguity
533 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
534 ret.push_back(c.getValue() + 1);
535 } else {
536 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
537 mod && mod.getKind() == AffineExprKind::Mod) {
538 // We can still infer a static bound for expressions in form
539 // "d % constant" since d % constant \in [0, constant).
540 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
541 ret.push_back(bound.getValue());
542 continue;
543 }
544 }
545 ret.push_back(ShapedType::kDynamic);
546 }
547 }
548 assert(ret.size() == rank);
549 return ret;
550}
551
553SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
554 ValueRange crds,
555 CrdTransDirectionKind dir) const {
556 if (!getImpl())
557 return crds;
558
559 SmallVector<Type> retType(
560 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
561 builder.getIndexType());
562 auto transOp =
563 CrdTranslateOp::create(builder, loc, retType, crds, dir, *this);
564 return transOp.getOutCrds();
565}
566
567Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
568 // Open "<{" part.
569 if (failed(parser.parseLess()))
570 return {};
571 if (failed(parser.parseLBrace()))
572 return {};
573
574 // Process the data from the parsed dictionary value into struct-like data.
575 SmallVector<LevelType> lvlTypes;
576 SmallVector<SparseTensorDimSliceAttr> dimSlices;
577 AffineMap dimToLvl = {};
578 AffineMap lvlToDim = {};
579 unsigned posWidth = 0;
580 unsigned crdWidth = 0;
581 Attribute explicitVal;
582 Attribute implicitVal;
583 StringRef attrName;
584 SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
585 "explicitVal", "implicitVal"};
586 while (succeeded(parser.parseOptionalKeyword(&attrName))) {
587 // Detect admissible keyword.
588 auto *it = find(keys, attrName);
589 if (it == keys.end()) {
590 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
591 return {};
592 }
593 unsigned keyWordIndex = it - keys.begin();
594 // Consume the `=` after keys
595 if (failed(parser.parseEqual()))
596 return {};
597 // Dispatch on keyword.
598 switch (keyWordIndex) {
599 case 0: { // map
600 ir_detail::DimLvlMapParser cParser(parser);
601 auto res = cParser.parseDimLvlMap();
602 if (failed(res))
603 return {};
604 const auto &dlm = *res;
605
606 const Level lvlRank = dlm.getLvlRank();
607 for (Level lvl = 0; lvl < lvlRank; lvl++)
608 lvlTypes.push_back(dlm.getLvlType(lvl));
609
610 const Dimension dimRank = dlm.getDimRank();
611 for (Dimension dim = 0; dim < dimRank; dim++)
612 dimSlices.push_back(dlm.getDimSlice(dim));
613 // NOTE: the old syntax requires an all-or-nothing approach to
614 // `dimSlices`; therefore, if any slice actually exists then we need
615 // to convert null-DSA into default/nop DSA.
616 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
617 return static_cast<bool>(slice.getImpl());
618 };
619 if (llvm::any_of(dimSlices, isDefined)) {
620 const auto defaultSlice =
621 SparseTensorDimSliceAttr::get(parser.getContext());
622 for (Dimension dim = 0; dim < dimRank; dim++)
623 if (!isDefined(dimSlices[dim]))
624 dimSlices[dim] = defaultSlice;
625 } else {
626 dimSlices.clear();
627 }
628
629 dimToLvl = dlm.getDimToLvlMap(parser.getContext());
630 lvlToDim = dlm.getLvlToDimMap(parser.getContext());
631 break;
632 }
633 case 1: { // posWidth
634 Attribute attr;
635 if (failed(parser.parseAttribute(attr)))
636 return {};
637 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
638 if (!intAttr) {
639 parser.emitError(parser.getNameLoc(),
640 "expected an integral position bitwidth");
641 return {};
642 }
643 posWidth = intAttr.getInt();
644 break;
645 }
646 case 2: { // crdWidth
647 Attribute attr;
648 if (failed(parser.parseAttribute(attr)))
649 return {};
650 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
651 if (!intAttr) {
652 parser.emitError(parser.getNameLoc(),
653 "expected an integral index bitwidth");
654 return {};
655 }
656 crdWidth = intAttr.getInt();
657 break;
658 }
659 case 3: { // explicitVal
660 Attribute attr;
661 if (failed(parser.parseAttribute(attr)))
662 return {};
663 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
664 explicitVal = result;
665 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
666 explicitVal = result;
667 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
668 explicitVal = result;
669 } else {
670 parser.emitError(parser.getNameLoc(),
671 "expected a numeric value for explicitVal");
672 return {};
673 }
674 break;
675 }
676 case 4: { // implicitVal
677 Attribute attr;
678 if (failed(parser.parseAttribute(attr)))
679 return {};
680 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
681 implicitVal = result;
682 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
683 implicitVal = result;
684 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
685 implicitVal = result;
686 } else {
687 parser.emitError(parser.getNameLoc(),
688 "expected a numeric value for implicitVal");
689 return {};
690 }
691 break;
692 }
693 } // switch
694 // Only last item can omit the comma.
695 if (parser.parseOptionalComma().failed())
696 break;
697 }
698
699 // Close "}>" part.
700 if (failed(parser.parseRBrace()))
701 return {};
702 if (failed(parser.parseGreater()))
703 return {};
704
705 // Construct struct-like storage for attribute.
706 if (!lvlToDim || lvlToDim.isEmpty()) {
707 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
708 }
709 return parser.getChecked<SparseTensorEncodingAttr>(
710 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
711 explicitVal, implicitVal, dimSlices);
712}
713
714void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
715 auto map = static_cast<AffineMap>(getDimToLvl());
716 // Empty affine map indicates identity map
717 if (!map)
718 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
719 printer << "<{ map = ";
720 printSymbols(map, printer);
721 printer << '(';
722 printDimensions(map, printer, getDimSlices());
723 printer << ") -> (";
724 printLevels(map, printer, getLvlTypes());
725 printer << ')';
726 // Print remaining members only for non-default values.
727 if (getPosWidth())
728 printer << ", posWidth = " << getPosWidth();
729 if (getCrdWidth())
730 printer << ", crdWidth = " << getCrdWidth();
731 if (getExplicitVal()) {
732 printer << ", explicitVal = " << getExplicitVal();
733 }
734 if (getImplicitVal())
735 printer << ", implicitVal = " << getImplicitVal();
736 printer << " }>";
737}
738
739void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
740 AsmPrinter &printer) const {
741 if (map.getNumSymbols() == 0)
742 return;
743 printer << '[';
744 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
745 printer << 's' << i << ", ";
746 if (map.getNumSymbols() >= 1)
747 printer << 's' << map.getNumSymbols() - 1;
748 printer << ']';
749}
750
751void SparseTensorEncodingAttr::printDimensions(
752 AffineMap &map, AsmPrinter &printer,
753 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
754 if (!dimSlices.empty()) {
755 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
756 printer << 'd' << i << " : " << dimSlices[i] << ", ";
757 if (map.getNumDims() >= 1) {
758 printer << 'd' << map.getNumDims() - 1 << " : "
759 << dimSlices[map.getNumDims() - 1];
760 }
761 } else {
762 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
763 printer << 'd' << i << ", ";
764 if (map.getNumDims() >= 1)
765 printer << 'd' << map.getNumDims() - 1;
766 }
767}
768
769void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
770 ArrayRef<LevelType> lvlTypes) const {
771 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
772 map.getResult(i).print(printer.getStream());
773 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
774 }
775 if (map.getNumResults() >= 1) {
776 auto lastIndex = map.getNumResults() - 1;
777 map.getResult(lastIndex).print(printer.getStream());
778 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
779 }
780}
781
782LogicalResult SparseTensorEncodingAttr::verify(
783 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
784 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
785 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
786 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
787 if (!acceptBitWidth(posWidth))
788 return emitError() << "unexpected position bitwidth: " << posWidth;
789 if (!acceptBitWidth(crdWidth))
790 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
791
792 // Verify every COO segment.
793 auto *it = llvm::find_if(lvlTypes, isSingletonLT);
794 while (it != lvlTypes.end()) {
795 if (it == lvlTypes.begin() ||
797 return emitError() << "expected compressed or loose_compressed level "
798 "before singleton level";
799
800 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
801 if (!std::all_of(it, curCOOEnd, isSingletonLT))
802 return emitError() << "expected all singleton lvlTypes "
803 "following a singleton level";
804 // We can potentially support mixed SoA/AoS singleton levels.
805 if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
806 return it->isa<LevelPropNonDefault::SoA>() ==
808 })) {
809 return emitError() << "expected all singleton lvlTypes stored in the "
810 "same memory layout (SoA vs AoS).";
811 }
812 it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
813 }
814
815 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
816 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
817 return emitError() << "Batch lvlType can only be leading levels.";
818
819 // SoA property can only be applied on singleton level.
820 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
821 return lt.isa<LevelPropNonDefault::SoA>();
822 });
823 if (llvm::any_of(soaLvls, [](LevelType lt) {
824 return !lt.isa<LevelFormat::Singleton>();
825 })) {
826 return emitError() << "SoA is only applicable to singleton lvlTypes.";
827 }
828
829 // TODO: audit formats that actually are supported by backend.
830 if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
831 it != std::end(lvlTypes)) {
832 if (it != lvlTypes.end() - 1)
833 return emitError() << "expected n_out_of_m to be the last level type";
834 if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
835 return emitError() << "expected all dense lvlTypes "
836 "before a n_out_of_m level";
837 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
838 if (!isBlockSparsity(dimToLvl)) {
839 return emitError()
840 << "expected 1xm block structure for n_out_of_m level";
841 }
842 auto sizes = getBlockSize(dimToLvl);
843 unsigned coefficient = 0;
844 for (const auto &elem : sizes) {
845 if (elem != 0) {
846 if (elem != coefficient && coefficient != 0) {
847 return emitError() << "expected only one blocked level "
848 "with the same coefficients";
849 }
850 coefficient = elem;
851 }
852 }
853 if (coefficient != getM(*it)) {
854 return emitError() << "expected coeffiencts of Affine expressions "
855 "to be equal to m of n_out_of_m level";
856 }
857 }
858 }
859 // Before we can check that the level-rank is consistent/coherent
860 // across all fields, we need to define it. The source-of-truth for
861 // the `getLvlRank` method is the length of the level-types array,
862 // since it must always be provided and have full rank; therefore we
863 // use that same source-of-truth here.
864 const Level lvlRank = lvlTypes.size();
865 if (lvlRank == 0)
866 return emitError() << "expected a non-empty array for lvlTypes";
867 // We save `dimRank` here because we'll also need it to verify `dimSlices`.
868 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
869 if (dimToLvl) {
870 if (dimToLvl.getNumResults() != lvlRank)
871 return emitError()
872 << "level-rank mismatch between dimToLvl and lvlTypes: "
873 << dimToLvl.getNumResults() << " != " << lvlRank;
874 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
875 // Symbols can't be inferred but are acceptable.
876 if (!inferRes && dimToLvl.getNumSymbols() == 0)
877 return emitError() << "failed to infer lvlToDim from dimToLvl";
878 if (lvlToDim && (inferRes != lvlToDim))
879 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
880 if (dimRank > lvlRank)
881 return emitError() << "unexpected dimToLvl mapping from " << dimRank
882 << " to " << lvlRank;
883 }
884 if (!dimSlices.empty()) {
885 if (dimSlices.size() != dimRank)
886 return emitError()
887 << "dimension-rank mismatch between dimSlices and dimToLvl: "
888 << dimSlices.size() << " != " << dimRank;
889 // Compiler support for `dimSlices` currently requires that the two
890 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
891 if (dimRank != lvlRank)
892 return emitError()
893 << "dimSlices expected dimension-rank to match level-rank: "
894 << dimRank << " != " << lvlRank;
895 }
896 return success();
897}
898
899LogicalResult SparseTensorEncodingAttr::verifyEncoding(
900 ArrayRef<Size> dimShape, Type elementType,
901 function_ref<InFlightDiagnostic()> emitError) const {
902 // Check structural integrity. In particular, this ensures that the
903 // level-rank is coherent across all the fields.
904 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
905 getPosWidth(), getCrdWidth(), getExplicitVal(),
906 getImplicitVal(), getDimSlices())))
907 return failure();
908 // Check integrity with tensor type specifics. In particular, we
909 // need only check that the dimension-rank of the tensor agrees with
910 // the dimension-rank of the encoding.
911 const Dimension dimRank = dimShape.size();
912 if (dimRank == 0)
913 return emitError() << "expected non-scalar sparse tensor";
914 if (getDimRank() != dimRank)
915 return emitError()
916 << "dimension-rank mismatch between encoding and tensor shape: "
917 << getDimRank() << " != " << dimRank;
918 if (auto expVal = getExplicitVal()) {
919 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
920 if (attrType != elementType) {
921 return emitError() << "explicit value type mismatch between encoding and "
922 << "tensor element type: " << attrType
923 << " != " << elementType;
924 }
925 }
926 if (auto impVal = getImplicitVal()) {
927 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
928 if (attrType != elementType) {
929 return emitError() << "implicit value type mismatch between encoding and "
930 << "tensor element type: " << attrType
931 << " != " << elementType;
932 }
933 // Currently, we only support zero as the implicit value.
934 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
935 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
936 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
937 if ((impFVal && impFVal.getValue().isNonZero()) ||
938 (impIntVal && !impIntVal.getValue().isZero()) ||
939 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
940 impComplexVal.getReal().isNonZero()))) {
941 return emitError() << "implicit value must be zero";
942 }
943 }
944 return success();
945}
946
947Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
948 SmallVector<COOSegment> coo = getCOOSegments();
949 assert(coo.size() == 1 || coo.empty());
950 if (!coo.empty() && coo.front().isAoS()) {
951 return coo.front().lvlRange.first;
952 }
953 return getLvlRank();
954}
955
956SmallVector<COOSegment>
957mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
958 SmallVector<COOSegment> ret;
959 if (getLvlRank() <= 1)
960 return ret;
961
962 ArrayRef<LevelType> lts = getLvlTypes();
963 Level l = 0;
964 while (l < getLvlRank()) {
965 auto lt = lts[l];
967 auto cur = lts.begin() + l;
968 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
969 return !lt.isa<LevelFormat::Singleton>();
970 });
971 unsigned cooLen = std::distance(cur, end);
972 if (cooLen > 1) {
973 // To support mixed SoA/AoS COO, we should break the segment when the
974 // storage scheme changes, for now we faithfully assume that all
975 // consecutive singleton levels have the same storage format as verified
976 // STEA.
977 ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
978 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
979 }
980 l += cooLen;
981 } else {
982 l++;
983 }
984 }
985 return ret;
986}
987
988//===----------------------------------------------------------------------===//
989// SparseTensorType Methods.
990//===----------------------------------------------------------------------===//
991
993 bool isUnique) const {
994 if (!hasEncoding())
995 return false;
996 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
997 return false;
998 for (Level l = startLvl + 1; l < lvlRank; ++l)
999 if (!isSingletonLvl(l))
1000 return false;
1001 // If isUnique is true, then make sure that the last level is unique,
1002 // that is, when lvlRank == 1, the only compressed level is unique,
1003 // and when lvlRank > 1, the last singleton is unique.
1004 return !isUnique || isUniqueLvl(lvlRank - 1);
1005}
1006
1007RankedTensorType
1009 SmallVector<LevelType> lvlTypes;
1010 lvlTypes.reserve(lvlRank);
1011 // A non-unique compressed level at beginning (unless this is
1012 // also the last level, then it is unique).
1013 lvlTypes.push_back(
1014 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
1015 if (lvlRank > 1) {
1016 // Followed by n-2 non-unique singleton levels.
1017 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1018 *buildLevelType(LevelFormat::Singleton, ordered, false));
1019 // Ends by a unique singleton level.
1020 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1021 }
1022 auto enc = SparseTensorEncodingAttr::get(
1023 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1025 return RankedTensorType::get(getDimShape(), getElementType(), enc);
1026}
1027
1028//===----------------------------------------------------------------------===//
1029// Convenience Methods.
1030//===----------------------------------------------------------------------===//
1031
1032SparseTensorEncodingAttr
1034 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1035 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1036 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1037 return mdtp.getEncoding();
1038 return nullptr;
1039}
1040
1042 MLIRContext *context) {
1043 auto map = static_cast<AffineMap>(dimToLvl);
1044 AffineMap lvlToDim;
1045 // Return an empty lvlToDim when inference is not successful.
1046 if (!map || map.getNumSymbols() != 0) {
1047 lvlToDim = AffineMap();
1048 } else if (map.isPermutation()) {
1049 lvlToDim = inversePermutation(map);
1050 } else if (isBlockSparsity(map)) {
1051 lvlToDim = inverseBlockSparsity(map, context);
1052 }
1053 return lvlToDim;
1054}
1055
1057 MLIRContext *context) {
1058 SmallVector<AffineExpr> lvlExprs;
1059 auto numLvls = dimToLvl.getNumResults();
1060 lvlExprs.reserve(numLvls);
1061 // lvlExprComponents stores information of the floordiv and mod operations
1062 // applied to the same dimension, so as to build the lvlToDim map.
1063 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1064 for (unsigned i = 0, n = numLvls; i < n; i++) {
1065 auto result = dimToLvl.getResult(i);
1066 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1067 if (result.getKind() == AffineExprKind::FloorDiv) {
1068 // Position of the dimension in dimToLvl.
1069 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1070 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1071 "expected only one floordiv for each dimension");
1072 SmallVector<AffineExpr, 3> components;
1073 // Level variable for floordiv.
1074 components.push_back(getAffineDimExpr(i, context));
1075 // Multiplier.
1076 components.push_back(binOp.getRHS());
1077 // Map key is the position of the dimension.
1078 lvlExprComponents[pos] = components;
1079 } else if (result.getKind() == AffineExprKind::Mod) {
1080 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1081 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1082 "expected floordiv before mod");
1083 // Add level variable for mod to the same vector
1084 // of the corresponding floordiv.
1085 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1086 } else {
1087 assert(false && "expected floordiv or mod");
1088 }
1089 } else {
1090 lvlExprs.push_back(getAffineDimExpr(i, context));
1091 }
1092 }
1093 // Build lvlExprs from lvlExprComponents.
1094 // For example, for il = i floordiv 2 and ii = i mod 2, the components
1095 // would be [il, 2, ii]. It could be used to build the AffineExpr
1096 // i = il * 2 + ii in lvlToDim.
1097 for (auto &components : lvlExprComponents) {
1098 assert(components.second.size() == 3 &&
1099 "expected 3 components to build lvlExprs");
1100 auto mulOp = getAffineBinaryOpExpr(
1101 AffineExprKind::Mul, components.second[0], components.second[1]);
1102 auto addOp =
1103 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1104 lvlExprs.push_back(addOp);
1105 }
1106 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1107}
1108
1110 assert(isBlockSparsity(dimToLvl) &&
1111 "expected dimToLvl to be block sparsity for calling getBlockSize");
1112 SmallVector<unsigned> blockSize;
1113 for (auto result : dimToLvl.getResults()) {
1114 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1115 if (result.getKind() == AffineExprKind::Mod) {
1116 blockSize.push_back(
1117 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1118 }
1119 } else {
1120 blockSize.push_back(0);
1121 }
1122 }
1123 return blockSize;
1124}
1125
1127 if (!dimToLvl)
1128 return false;
1129 std::map<unsigned, int64_t> coeffientMap;
1130 bool hasBlock = false;
1131 for (auto result : dimToLvl.getResults()) {
1132 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1133 // Check for "dim op const".
1134 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1135 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1136 if (!dimOp || !conOp || conOp.getValue() <= 0)
1137 return false;
1138 // Inspect "dim / const" or "dim % const".
1139 auto pos = dimOp.getPosition();
1140 if (binOp.getKind() == AffineExprKind::FloorDiv) {
1141 // Expect only one floordiv for each dimension.
1142 auto [it, inserted] = coeffientMap.try_emplace(pos);
1143 if (!inserted)
1144 return false;
1145 // Record coefficient of the floordiv.
1146 it->second = conOp.getValue();
1147 } else if (binOp.getKind() == AffineExprKind::Mod) {
1148 // Expect floordiv before mod.
1149 auto it = coeffientMap.find(pos);
1150 if (it == coeffientMap.end())
1151 return false;
1152 // Expect mod to have the same coefficient as floordiv.
1153 if (conOp.getValue() != it->second)
1154 return false;
1155 hasBlock = true;
1156 } else {
1157 return false;
1158 }
1159 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1160 auto pos = dimOp.getPosition();
1161 // Expect dim to be unset.
1162 if (!coeffientMap.try_emplace(pos, 0).second)
1163 return false;
1164 } else {
1165 return false;
1166 }
1167 }
1168 return hasBlock;
1169}
1170
1172 auto hasNonIdentityMap = [](Value v) {
1173 auto stt = tryGetSparseTensorType(v);
1174 return stt && !stt->isIdentity();
1175 };
1176
1177 return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1178 llvm::any_of(op->getResults(), hasNonIdentityMap);
1179}
1180
1181Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1182 if (enc) {
1183 assert(enc.isPermutation() && "Non permutation map not supported");
1184 if (const auto dimToLvl = enc.getDimToLvl())
1185 return dimToLvl.getDimPosition(l);
1186 }
1187 return l;
1188}
1189
1190Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1191 if (enc) {
1192 assert(enc.isPermutation() && "Non permutation map not supported");
1193 if (const auto lvlToDim = enc.getLvlToDim())
1194 return lvlToDim.getDimPosition(d);
1195 }
1196 return d;
1197}
1198
1199/// We normalized sparse tensor encoding attribute by always using
1200/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1201/// as other variants) lead to the same storage specifier type, and stripping
1202/// irrelevant fields that do not alter the sparse tensor memory layout.
1203static SparseTensorEncodingAttr
1204getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1206 for (auto lt : enc.getLvlTypes())
1207 lts.push_back(lt.stripStorageIrrelevantProperties());
1208
1209 return SparseTensorEncodingAttr::get(
1210 enc.getContext(), lts,
1211 AffineMap(), // dimToLvl (irrelevant to storage specifier)
1212 AffineMap(), // lvlToDim (irrelevant to storage specifier)
1213 // Always use `index` for memSize and lvlSize instead of reusing
1214 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
1215 // value for different bitwidth, it also avoids casting between index and
1216 // integer (returned by DimOp)
1217 0, 0,
1218 Attribute(), // explicitVal (irrelevant to storage specifier)
1219 Attribute(), // implicitVal (irrelevant to storage specifier)
1220 enc.getDimSlices());
1221}
1222
1223StorageSpecifierType
1224StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1225 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1226}
1227
1228StorageSpecifierType
1229StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1230 MLIRContext *ctx,
1231 SparseTensorEncodingAttr encoding) {
1232 return Base::getChecked(emitError, ctx,
1234}
1235
1236//===----------------------------------------------------------------------===//
1237// SparseTensorDialect Operations.
1238//===----------------------------------------------------------------------===//
1239
1240static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1241 return success(lvl < getSparseTensorType(tensor).getLvlRank());
1242}
1243
1244static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1245 const Type etp = getMemRefType(mem).getElementType();
1246 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1247}
1248
1249static LogicalResult verifySparsifierGetterSetter(
1250 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1252 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1253 return op->emitError(
1254 "redundant level argument for querying value memory size");
1255 }
1256
1257 const auto enc = md.getType().getEncoding();
1258 const Level lvlRank = enc.getLvlRank();
1259
1260 if (mdKind == StorageSpecifierKind::DimOffset ||
1261 mdKind == StorageSpecifierKind::DimStride)
1262 if (!enc.isSlice())
1263 return op->emitError("requested slice data on non-slice tensor");
1264
1265 if (mdKind != StorageSpecifierKind::ValMemSize) {
1266 if (!lvl)
1267 return op->emitError("missing level argument");
1268
1269 const Level l = lvl.value();
1270 if (l >= lvlRank)
1271 return op->emitError("requested level is out of bounds");
1272
1273 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1274 return op->emitError(
1275 "requested position memory size on a singleton level");
1276 }
1277 return success();
1278}
1279
1281 switch (kind) {
1283 return stt.getCrdType();
1285 return stt.getPosType();
1287 return stt.getElementType();
1289 return nullptr;
1290 }
1291 llvm_unreachable("Unrecognizable FieldKind");
1292}
1293
1294static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1295 SparseTensorType stt,
1296 RankedTensorType valTp,
1297 TypeRange lvlTps) {
1298 if (requiresStaticShape && !stt.hasStaticDimShape())
1299 return op->emitError("the sparse-tensor must have static shape");
1300 if (!stt.hasEncoding())
1301 return op->emitError("the sparse-tensor must have an encoding attribute");
1302
1303 // Verifies the trailing COO.
1304 Level cooStartLvl = stt.getAoSCOOStart();
1305 if (cooStartLvl < stt.getLvlRank()) {
1306 // We only supports trailing COO for now, must be the last input.
1307 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1308 // The coordinates should be in shape of <? x rank>
1309 unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1310 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1311 return op->emitError("input/output trailing COO level-ranks don't match");
1312 }
1313 }
1314
1315 // Verifies that all types match.
1316 StorageLayout layout(stt.getEncoding());
1317 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1318 return op->emitError("inconsistent number of fields between input/output");
1319
1320 unsigned idx = 0;
1321 bool misMatch = false;
1322 layout.foreachField([&idx, &misMatch, stt, valTp,
1323 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1324 Level lvl, LevelType lt) -> bool {
1326 return true;
1327
1328 Type inputTp = nullptr;
1329 if (fKind == SparseTensorFieldKind::ValMemRef) {
1330 inputTp = valTp;
1331 } else {
1332 assert(fid == idx && stt.getLvlType(lvl) == lt);
1333 inputTp = lvlTps[idx++];
1334 }
1335 // The input element type and expected element type should match.
1336 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1337 Type expElemTp = getFieldElemType(stt, fKind);
1338 if (inpElemTp != expElemTp) {
1339 misMatch = true;
1340 return false; // to terminate the iteration
1341 }
1342 return true;
1343 });
1344
1345 if (misMatch)
1346 return op->emitError("input/output element-types don't match");
1347 return success();
1348}
1349
1350LogicalResult AssembleOp::verify() {
1351 RankedTensorType valuesTp = getValues().getType();
1352 const auto lvlsTp = getLevels().getTypes();
1353 const auto resTp = getSparseTensorType(getResult());
1354 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1355}
1356
1357LogicalResult DisassembleOp::verify() {
1358 if (getOutValues().getType() != getRetValues().getType())
1359 return emitError("output values and return value type mismatch");
1360
1361 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1362 if (ot.getType() != rt.getType())
1363 return emitError("output levels and return levels type mismatch");
1364
1365 RankedTensorType valuesTp = getRetValues().getType();
1366 const auto lvlsTp = getRetLevels().getTypes();
1367 const auto srcTp = getSparseTensorType(getTensor());
1368 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1369}
1370
1371LogicalResult ConvertOp::verify() {
1372 RankedTensorType tp1 = getSource().getType();
1373 RankedTensorType tp2 = getDest().getType();
1374 if (tp1.getRank() != tp2.getRank())
1375 return emitError("unexpected conversion mismatch in rank");
1376 auto dstEnc =
1377 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1378 if (dstEnc && dstEnc.isSlice())
1379 return emitError("cannot convert to a sparse tensor slice");
1380
1381 auto shape1 = tp1.getShape();
1382 auto shape2 = tp2.getShape();
1383 // Accept size matches between the source and the destination type
1384 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1385 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1386 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1387 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1388 return emitError("unexpected conversion mismatch in dimension ") << d;
1389 return success();
1390}
1391
1392OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1393 if (getType() == getSource().getType())
1394 return getSource();
1395 return {};
1396}
1397
1398bool ConvertOp::needsExtraSort() {
1399 SparseTensorType srcStt = getSparseTensorType(getSource());
1400 SparseTensorType dstStt = getSparseTensorType(getDest());
1401
1402 // We do not need an extra sort when returning unordered sparse tensors or
1403 // dense tensor since dense tensor support random access.
1404 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1405 return false;
1406
1407 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1408 srcStt.hasSameDimToLvl(dstStt)) {
1409 return false;
1410 }
1411
1412 // Source and dest tensors are ordered in different ways. We only do direct
1413 // dense to sparse conversion when the dense input is defined by a sparse
1414 // constant. Note that we can theoretically always directly convert from dense
1415 // inputs by rotating dense loops but it leads to bad cache locality and hurt
1416 // performance.
1417 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1418 if (isa<SparseElementsAttr>(constOp.getValue()))
1419 return false;
1420
1421 return true;
1422}
1423
1424LogicalResult CrdTranslateOp::verify() {
1425 uint64_t inRank = getEncoder().getLvlRank();
1426 uint64_t outRank = getEncoder().getDimRank();
1427
1428 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1429 std::swap(inRank, outRank);
1430
1431 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1432 return emitError("Coordinate rank mismatch with encoding");
1433
1434 return success();
1435}
1436
1437LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1438 SmallVectorImpl<OpFoldResult> &results) {
1439 if (getEncoder().isIdentity()) {
1440 results.assign(getInCrds().begin(), getInCrds().end());
1441 return success();
1442 }
1443 if (getEncoder().isPermutation()) {
1444 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1445 ? getEncoder().getDimToLvl()
1446 : getEncoder().getLvlToDim();
1447 for (AffineExpr exp : perm.getResults())
1448 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1449 return success();
1450 }
1451
1452 // Fuse dim2lvl/lvl2dim pairs.
1453 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1454 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1455 return v.getDefiningOp() == def;
1456 });
1457 if (!sameDef)
1458 return failure();
1459
1460 bool oppositeDir = def.getDirection() != getDirection();
1461 bool sameOracle =
1462 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1463 bool sameCount = def.getNumResults() == getInCrds().size();
1464 if (!oppositeDir || !sameOracle || !sameCount)
1465 return failure();
1466
1467 // The definition produces the coordinates in the same order as the input
1468 // coordinates.
1469 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1470 [](auto valuePair) {
1471 auto [lhs, rhs] = valuePair;
1472 return lhs == rhs;
1473 });
1474
1475 if (!sameOrder)
1476 return failure();
1477 // l1 = dim2lvl (lvl2dim l0)
1478 // ==> l0
1479 results.append(def.getInCrds().begin(), def.getInCrds().end());
1480 return success();
1481}
1482
1483void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1484 int64_t index) {
1485 Value val = arith::ConstantIndexOp::create(builder, state.location, index);
1486 return build(builder, state, source, val);
1487}
1488
1489LogicalResult LvlOp::verify() {
1490 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1491 auto stt = getSparseTensorType(getSource());
1492 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1493 return emitError(
1494 "Level index exceeds the rank of the input sparse tensor");
1495 }
1496 return success();
1497}
1498
1499std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1500 return getConstantIntValue(getIndex());
1501}
1502
1503Speculation::Speculatability LvlOp::getSpeculatability() {
1504 auto constantIndex = getConstantLvlIndex();
1505 if (!constantIndex)
1507
1508 assert(constantIndex <
1509 cast<RankedTensorType>(getSource().getType()).getRank());
1511}
1512
1513OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1514 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1515 if (!lvlIndex)
1516 return {};
1517
1518 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1519 auto stt = getSparseTensorType(getSource());
1520 if (lvl >= stt.getLvlRank()) {
1521 // Follows the same convention used by tensor.dim operation. Out of bound
1522 // indices produce undefined behavior but are still valid IR. Don't choke on
1523 // them.
1524 return {};
1525 }
1526
1527 // Helper lambda to build an IndexAttr.
1528 auto getIndexAttr = [this](int64_t lvlSz) {
1529 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1530 };
1531
1532 SmallVector<Size> lvlShape = stt.getLvlShape();
1533 if (ShapedType::isStatic(lvlShape[lvl]))
1534 return getIndexAttr(lvlShape[lvl]);
1535
1536 return {};
1537}
1538
1539void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1540 SparseTensorEncodingAttr dstEnc, Value source) {
1541 auto srcStt = getSparseTensorType(source);
1542 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1543 SmallVector<int64_t> dstDimShape =
1544 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1545 auto dstTp =
1546 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1547 return build(odsBuilder, odsState, dstTp, source);
1548}
1549
1550LogicalResult ReinterpretMapOp::verify() {
1551 auto srcStt = getSparseTensorType(getSource());
1552 auto dstStt = getSparseTensorType(getDest());
1553 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1554 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1555
1556 if (srcLvlTps.size() != dstLvlTps.size())
1557 return emitError("Level rank mismatch between source/dest tensors");
1558
1559 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1560 if (srcLvlTp != dstLvlTp)
1561 return emitError("Level type mismatch between source/dest tensors");
1562
1563 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1564 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1565 return emitError("Crd/Pos width mismatch between source/dest tensors");
1566 }
1567
1568 if (srcStt.getElementType() != dstStt.getElementType())
1569 return emitError("Element type mismatch between source/dest tensors");
1570
1571 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1572 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1573 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1574 if (srcLvlSz != dstLvlSz) {
1575 // Should we allow one side to be dynamic size, e.g., <?x?> should be
1576 // compatible to <3x4>? For now, we require all the level sizes to be
1577 // *exactly* matched for simplicity.
1578 return emitError("Level size mismatch between source/dest tensors");
1579 }
1580 }
1581
1582 return success();
1583}
1584
1585OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1586 if (getSource().getType() == getDest().getType())
1587 return getSource();
1588
1589 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1590 // A -> B, B -> A ==> A
1591 if (def.getSource().getType() == getDest().getType())
1592 return def.getSource();
1593 }
1594 return {};
1595}
1596
1597template <typename ToBufferOp>
1598static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1599 OpaqueProperties prop,
1600 RegionRange region,
1602 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1603 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1604 Type elemTp = nullptr;
1605 bool withStride = false;
1606 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1607 elemTp = stt.getPosType();
1608 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1609 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1610 elemTp = stt.getCrdType();
1611 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1612 withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1613 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1614 elemTp = stt.getElementType();
1615 }
1616
1617 assert(elemTp && "unhandled operation.");
1618 SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1619 bufShape.push_back(ShapedType::kDynamic);
1620
1621 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1622 stt.getContext(), ShapedType::kDynamic,
1623 {ShapedType::kDynamic})
1624 : StridedLayoutAttr();
1625 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1626 return success();
1627}
1628
1629LogicalResult ToPositionsOp::verify() {
1630 auto stt = getSparseTensorType(getTensor());
1631 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1632 return emitError("requested level is out of bounds");
1633 if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1634 return emitError("unexpected type for positions");
1635 return success();
1636}
1637
1638LogicalResult
1639ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1640 ValueRange ops, DictionaryAttr attr,
1641 OpaqueProperties prop, RegionRange region,
1642 SmallVectorImpl<mlir::Type> &ret) {
1643 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1644}
1645
1646LogicalResult ToCoordinatesOp::verify() {
1647 auto stt = getSparseTensorType(getTensor());
1648 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1649 return emitError("requested level is out of bounds");
1650 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1651 return emitError("unexpected type for coordinates");
1652 return success();
1653}
1654
1655LogicalResult
1656ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1657 ValueRange ops, DictionaryAttr attr,
1658 OpaqueProperties prop, RegionRange region,
1659 SmallVectorImpl<mlir::Type> &ret) {
1660 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1661}
1662
1663LogicalResult ToCoordinatesBufferOp::verify() {
1664 auto stt = getSparseTensorType(getTensor());
1665 if (stt.getAoSCOOStart() >= stt.getLvlRank())
1666 return emitError("expected sparse tensor with a COO region");
1667 return success();
1668}
1669
1670LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1671 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1672 DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1673 SmallVectorImpl<mlir::Type> &ret) {
1674 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1675 ret);
1676}
1677
1678LogicalResult ToValuesOp::verify() {
1679 auto stt = getSparseTensorType(getTensor());
1680 auto mtp = getMemRefType(getResult());
1681 if (stt.getElementType() != mtp.getElementType())
1682 return emitError("unexpected mismatch in element types");
1683 return success();
1684}
1685
1686LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1687 std::optional<Location> loc,
1688 ValueRange ops, DictionaryAttr attr,
1689 OpaqueProperties prop,
1690 RegionRange region,
1691 SmallVectorImpl<mlir::Type> &ret) {
1692 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1693}
1694
1695LogicalResult ToSliceOffsetOp::verify() {
1696 auto rank = getSlice().getType().getRank();
1697 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1698 return emitError("requested dimension out of bound");
1699 return success();
1700}
1701
1702LogicalResult ToSliceStrideOp::verify() {
1703 auto rank = getSlice().getType().getRank();
1704 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1705 return emitError("requested dimension out of bound");
1706 return success();
1707}
1708
1709LogicalResult GetStorageSpecifierOp::verify() {
1710 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1711 getSpecifier(), getOperation());
1712}
1713
1714template <typename SpecifierOp>
1715static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1716 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1717}
1718
1719OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1720 const StorageSpecifierKind kind = getSpecifierKind();
1721 const auto lvl = getLevel();
1722 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1723 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1724 return op.getValue();
1725 return {};
1726}
1727
1728LogicalResult SetStorageSpecifierOp::verify() {
1729 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1730 getSpecifier(), getOperation());
1731}
1732
1733template <class T>
1734static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1735 const char *regionName,
1736 TypeRange inputTypes, Type outputType) {
1737 unsigned numArgs = region.getNumArguments();
1738 unsigned expectedNum = inputTypes.size();
1739 if (numArgs != expectedNum)
1740 return op->emitError() << regionName << " region must have exactly "
1741 << expectedNum << " arguments";
1742
1743 for (unsigned i = 0; i < numArgs; i++) {
1744 Type typ = region.getArgument(i).getType();
1745 if (typ != inputTypes[i])
1746 return op->emitError() << regionName << " region argument " << (i + 1)
1747 << " type mismatch";
1748 }
1749 Operation *term = region.front().getTerminator();
1750 YieldOp yield = dyn_cast<YieldOp>(term);
1751 if (!yield)
1752 return op->emitError() << regionName
1753 << " region must end with sparse_tensor.yield";
1754 if (!yield.hasSingleResult() ||
1755 yield.getSingleResult().getType() != outputType)
1756 return op->emitError() << regionName << " region yield type mismatch";
1757
1758 return success();
1759}
1760
1761LogicalResult BinaryOp::verify() {
1762 NamedAttrList attrs = (*this)->getAttrs();
1763 Type leftType = getX().getType();
1764 Type rightType = getY().getType();
1765 Type outputType = getOutput().getType();
1766 Region &overlap = getOverlapRegion();
1767 Region &left = getLeftRegion();
1768 Region &right = getRightRegion();
1769
1770 // Check correct number of block arguments and return type for each
1771 // non-empty region.
1772 if (!overlap.empty()) {
1773 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1774 TypeRange{leftType, rightType}, outputType)))
1775 return failure();
1776 }
1777 if (!left.empty()) {
1778 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1779 outputType)))
1780 return failure();
1781 } else if (getLeftIdentity()) {
1782 if (leftType != outputType)
1783 return emitError("left=identity requires first argument to have the same "
1784 "type as the output");
1785 }
1786 if (!right.empty()) {
1787 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1788 outputType)))
1789 return failure();
1790 } else if (getRightIdentity()) {
1791 if (rightType != outputType)
1792 return emitError("right=identity requires second argument to have the "
1793 "same type as the output");
1794 }
1795 return success();
1796}
1797
1798LogicalResult UnaryOp::verify() {
1799 Type inputType = getX().getType();
1800 Type outputType = getOutput().getType();
1801
1802 // Check correct number of block arguments and return type for each
1803 // non-empty region.
1804 Region &present = getPresentRegion();
1805 if (!present.empty()) {
1806 if (failed(verifyNumBlockArgs(this, present, "present",
1807 TypeRange{inputType}, outputType)))
1808 return failure();
1809 }
1810 Region &absent = getAbsentRegion();
1811 if (!absent.empty()) {
1812 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1813 outputType)))
1814 return failure();
1815 // Absent branch can only yield invariant values.
1816 Block *absentBlock = &absent.front();
1817 Block *parent = getOperation()->getBlock();
1818 Value absentVal =
1819 cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1820 if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1821 if (arg.getOwner() == parent)
1822 return emitError("absent region cannot yield linalg argument");
1823 } else if (Operation *def = absentVal.getDefiningOp()) {
1824 if (!isa<arith::ConstantOp>(def) &&
1825 (def->getBlock() == absentBlock || def->getBlock() == parent))
1826 return emitError("absent region cannot yield locally computed value");
1827 }
1828 }
1829 return success();
1830}
1831
1832bool ConcatenateOp::needsExtraSort() {
1833 SparseTensorType dstStt = getSparseTensorType(*this);
1834 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1835 return false;
1836
1837 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1838 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1839 });
1840 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1841 // in all input/output buffers, and all input/output buffers have the same
1842 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1843 // CSC matrices along column).
1844 bool directLowerable =
1845 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1846 return !directLowerable;
1847}
1848
1849LogicalResult ConcatenateOp::verify() {
1850 const auto dstTp = getSparseTensorType(*this);
1851 const Dimension concatDim = getDimension();
1852 const Dimension dimRank = dstTp.getDimRank();
1853
1854 if (getInputs().size() <= 1)
1855 return emitError("Need at least two tensors to concatenate.");
1856
1857 if (concatDim >= dimRank)
1858 return emitError(llvm::formatv(
1859 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1860 concatDim, dimRank));
1861
1862 for (const auto &it : llvm::enumerate(getInputs())) {
1863 const auto i = it.index();
1864 const auto srcTp = getSparseTensorType(it.value());
1865 if (srcTp.hasDynamicDimShape())
1866 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1867 const Dimension srcDimRank = srcTp.getDimRank();
1868 if (srcDimRank != dimRank)
1869 return emitError(
1870 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1871 "from the output tensor (rank={2}).",
1872 i, srcDimRank, dimRank));
1873 }
1874
1875 for (Dimension d = 0; d < dimRank; d++) {
1876 const Size dstSh = dstTp.getDimShape()[d];
1877 if (d == concatDim) {
1878 if (ShapedType::isStatic(dstSh)) {
1879 // If we reach here, then all inputs have static shapes. So we
1880 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1881 // to avoid redundant assertions in the loop.
1882 Size sumSz = 0;
1883 for (const auto src : getInputs())
1884 sumSz += getSparseTensorType(src).getDimShape()[d];
1885 // If all dimension are statically known, the sum of all the input
1886 // dimensions should be equal to the output dimension.
1887 if (sumSz != dstSh)
1888 return emitError(
1889 "The concatenation dimension of the output tensor should be the "
1890 "sum of all the concatenation dimensions of the input tensors.");
1891 }
1892 } else {
1893 Size prev = dstSh;
1894 for (const auto src : getInputs()) {
1895 const auto sh = getSparseTensorType(src).getDimShape()[d];
1896 if (ShapedType::isStatic(prev) && sh != prev)
1897 return emitError("All dimensions (expect for the concatenating one) "
1898 "should be equal.");
1899 prev = sh;
1900 }
1901 }
1902 }
1903
1904 return success();
1905}
1906
1907void PushBackOp::build(OpBuilder &builder, OperationState &result,
1908 Value curSize, Value inBuffer, Value value) {
1909 build(builder, result, curSize, inBuffer, value, Value());
1910}
1911
1912LogicalResult PushBackOp::verify() {
1913 if (Value n = getN()) {
1914 std::optional<int64_t> nValue = getConstantIntValue(n);
1915 if (nValue && nValue.value() < 1)
1916 return emitOpError("n must be not less than 1");
1917 }
1918 return success();
1919}
1920
1921LogicalResult CompressOp::verify() {
1922 const auto stt = getSparseTensorType(getTensor());
1923 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1924 return emitOpError("incorrect number of coordinates");
1925 return success();
1926}
1927
1928void ForeachOp::build(
1929 OpBuilder &builder, OperationState &result, Value tensor,
1930 ValueRange initArgs, AffineMapAttr order,
1931 function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1932 bodyBuilder) {
1933 build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1934 // Builds foreach body.
1935 if (!bodyBuilder)
1936 return;
1937 const auto stt = getSparseTensorType(tensor);
1938 const Dimension dimRank = stt.getDimRank();
1939
1940 // Starts with `dimRank`-many coordinates.
1941 SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1942 // Followed by one value.
1943 blockArgTypes.push_back(stt.getElementType());
1944 // Followed by the reduction variables.
1945 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1946
1947 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1948
1949 OpBuilder::InsertionGuard guard(builder);
1950 auto &region = *result.regions.front();
1951 Block *bodyBlock =
1952 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1953 bodyBuilder(builder, result.location,
1954 bodyBlock->getArguments().slice(0, dimRank),
1955 bodyBlock->getArguments()[dimRank],
1956 bodyBlock->getArguments().drop_front(dimRank + 1));
1957}
1958
1959LogicalResult ForeachOp::verify() {
1960 const auto t = getSparseTensorType(getTensor());
1961 const Dimension dimRank = t.getDimRank();
1962 const auto args = getBody()->getArguments();
1963
1964 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1965 return emitError("Level traverse order does not match tensor's level rank");
1966
1967 if (dimRank + 1 + getInitArgs().size() != args.size())
1968 return emitError("Unmatched number of arguments in the block");
1969
1970 if (getNumResults() != getInitArgs().size())
1971 return emitError("Mismatch in number of init arguments and results");
1972
1973 if (getResultTypes() != getInitArgs().getTypes())
1974 return emitError("Mismatch in types of init arguments and results");
1975
1976 // Cannot mark this const, because the getters aren't.
1977 auto yield = cast<YieldOp>(getBody()->getTerminator());
1978 if (yield.getNumOperands() != getNumResults() ||
1979 yield.getOperands().getTypes() != getResultTypes())
1980 return emitError("Mismatch in types of yield values and results");
1981
1982 const auto iTp = IndexType::get(getContext());
1983 for (Dimension d = 0; d < dimRank; d++)
1984 if (args[d].getType() != iTp)
1985 return emitError(
1986 llvm::formatv("Expecting Index type for argument at index {0}", d));
1987
1988 const auto elemTp = t.getElementType();
1989 const auto valueTp = args[dimRank].getType();
1990 if (elemTp != valueTp)
1991 return emitError(
1992 llvm::formatv("Unmatched element type between input tensor and "
1993 "block argument, expected:{0}, got: {1}",
1994 elemTp, valueTp));
1995 return success();
1996}
1997
1998OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1999 if (getSparseTensorEncoding(getInputCoo().getType()) ==
2000 getSparseTensorEncoding(getResultCoo().getType()))
2001 return getInputCoo();
2002
2003 return {};
2004}
2005
2006LogicalResult ReorderCOOOp::verify() {
2007 SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2008 SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2009
2010 if (!srcStt.isCOOType() || !dstStt.isCOOType())
2011 return emitError("Expected COO sparse tensors only");
2012
2013 if (!srcStt.hasSameDimToLvl(dstStt))
2014 return emitError("Unmatched dim2lvl map between input and result COO");
2015
2016 if (srcStt.getPosType() != dstStt.getPosType() ||
2017 srcStt.getCrdType() != dstStt.getCrdType() ||
2018 srcStt.getElementType() != dstStt.getElementType())
2019 return emitError("Unmatched storage format between input and result COO");
2020
2021 return success();
2022}
2023
2024LogicalResult ReduceOp::verify() {
2025 Type inputType = getX().getType();
2026 Region &formula = getRegion();
2027 return verifyNumBlockArgs(this, formula, "reduce",
2028 TypeRange{inputType, inputType}, inputType);
2029}
2030
2031LogicalResult SelectOp::verify() {
2032 Builder b(getContext());
2033 Type inputType = getX().getType();
2034 Type boolType = b.getI1Type();
2035 Region &formula = getRegion();
2036 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2037 boolType);
2038}
2039
2040LogicalResult SortOp::verify() {
2041 AffineMap xPerm = getPermMap();
2042 uint64_t nx = xPerm.getNumDims();
2043 if (nx < 1)
2044 return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2045
2046 if (!xPerm.isPermutation())
2047 return emitError(
2048 llvm::formatv("Expected a permutation map, got {0}", xPerm));
2049
2050 // We can't check the size of the buffers when n or buffer dimensions aren't
2051 // compile-time constants.
2052 std::optional<int64_t> cn = getConstantIntValue(getN());
2053 if (!cn)
2054 return success();
2055
2056 // Verify dimensions.
2057 const auto checkDim = [&](Value v, Size minSize,
2058 const char *message) -> LogicalResult {
2059 const Size sh = getMemRefType(v).getShape()[0];
2060 if (ShapedType::isStatic(sh) && sh < minSize)
2061 return emitError(
2062 llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2063 return success();
2064 };
2065 uint64_t n = cn.value();
2066 uint64_t ny = 0;
2067 if (auto nyAttr = getNyAttr())
2068 ny = nyAttr.getInt();
2069 if (failed(checkDim(getXy(), n * (nx + ny),
2070 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2071 return failure();
2072 for (Value opnd : getYs())
2073 if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2074 return failure();
2075
2076 return success();
2077}
2078
2079//===----------------------------------------------------------------------===//
2080// Sparse Tensor Iteration Operations.
2081//===----------------------------------------------------------------------===//
2082
2083IterSpaceType IteratorType::getIterSpaceType() const {
2084 return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2085 getHiLvl());
2086}
2087
2088IteratorType IterSpaceType::getIteratorType() const {
2089 return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2090}
2091
2092/// Parses a level range in the form "$lo `to` $hi"
2093/// or simply "$lo" if $hi - $lo = 1
2094static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2095 Level &lvlHi) {
2096 if (parser.parseInteger(lvlLo))
2097 return failure();
2098
2099 if (succeeded(parser.parseOptionalKeyword("to"))) {
2100 if (parser.parseInteger(lvlHi))
2101 return failure();
2102 } else {
2103 lvlHi = lvlLo + 1;
2104 }
2105
2106 if (lvlHi <= lvlLo)
2107 return parser.emitError(parser.getNameLoc(),
2108 "expect larger level upper bound than lower bound");
2109
2110 return success();
2111}
2112
2113/// Parses a level range in the form "$lo `to` $hi"
2114/// or simply "$lo" if $hi - $lo = 1
2115static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2116 IntegerAttr &lvlHiAttr) {
2117 Level lvlLo, lvlHi;
2118 if (parseLevelRange(parser, lvlLo, lvlHi))
2119 return failure();
2120
2121 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2122 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2123 return success();
2124}
2125
2126/// Prints a level range in the form "$lo `to` $hi"
2127/// or simply "$lo" if $hi - $lo = 1
2128static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2129
2130 if (lo + 1 == hi)
2131 p << lo;
2132 else
2133 p << lo << " to " << hi;
2134}
2135
2136/// Prints a level range in the form "$lo `to` $hi"
2137/// or simply "$lo" if $hi - $lo = 1
2138static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2139 IntegerAttr lvlHi) {
2140 unsigned lo = lvlLo.getValue().getZExtValue();
2141 unsigned hi = lvlHi.getValue().getZExtValue();
2142 printLevelRange(p, lo, hi);
2143}
2144
/// Parses a list of `optional` defined list in the form of
/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
/// corresponding value is not defined (e.g., to represent an undefined
/// coordinate in the sparse iteration space).
///
/// For each entry that is an SSA argument (not `_`), the argument is appended
/// to `definedArgs` and its position is set in `definedSet`. Fails when the
/// list is malformed or has more than `maxCnt` entries.
/// NOTE(review): positions are recorded in an I64BitSet. Presumably lists
/// longer than 64 entries are not expected; confirm the bit-set's bounds.
static ParseResult parseOptionalDefinedList(
    OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
    unsigned maxCnt = std::numeric_limits<unsigned>::max(),
  // Number of entries seen so far (both defined values and `_`).
  unsigned cnt = 0;
  ParseResult crdList =
      parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
        // Not a `_` placeholder: expect an SSA argument and mark the position.
        if (parser.parseOptionalKeyword("_")) {
          if (parser.parseArgument(definedArgs.emplace_back()))
            return failure();
          definedSet.set(cnt);
        }
        cnt += 1;
        return success();
      });

  if (cnt > maxCnt)
    return parser.emitError(parser.getNameLoc(),
                            "parsed more value than expected.");

  if (failed(crdList)) {
    return parser.emitError(
        parser.getNameLoc(),
        "expecting SSA value or \"_\" for level coordinates");
  }
  // Assumes `definedArgs` was empty on entry, so the two must agree.
  assert(definedArgs.size() == definedSet.count());
  return success();
}
2178
2179static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2180 Block::BlockArgListType blocksArgs,
2181 I64BitSet definedSet) {
2182 if (definedSet.empty())
2183 return;
2184
2185 for (unsigned i = 0; i < size; i++) {
2186 if (definedSet[i]) {
2187 p << blocksArgs.front();
2188 blocksArgs = blocksArgs.drop_front();
2189 } else {
2190 p << "_";
2191 }
2192 if (i != size - 1)
2193 p << ", ";
2194 }
2195 assert(blocksArgs.empty());
2196}
2197
/// Parses an optional "at(%crd0, _, ...)" clause. It lists the coordinate
/// arguments of the levels whose coordinates are used; `_` marks an unused
/// level. Every parsed coordinate argument is given `index` type. The set of
/// used levels is always stored on `state` as the `crdUsedLvls` attribute; it
/// stays default-constructed when the clause is absent.
static ParseResult
  // Parse "at(%crd0, _, ...)"
  I64BitSet crdUsedLvlSet;
  if (succeeded(parser.parseOptionalKeyword("at")) &&
      failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
    return failure();

  // Always use IndexType for the coordinate.
  for (auto &coord : coords)
    coord.type = parser.getBuilder().getIndexType();

  // Set the CrdUsedLvl bitset.
  state.addAttribute("crdUsedLvls",
                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
  return success();
}
2216
/// Parses the header of a `sparse_tensor.iterate` loop:
///   %it, ... in %space, ... [at(%crd, _, ...)] [iter_args(%a = %init, ...)]
///     : <iteration space types> [-> <result types>]
/// On success:
///  - each iterator argument gets the iterator type derived from the
///    corresponding `sparse_tensor.iter_space` type;
///  - the iteration-space operands, then the init operands, are resolved
///    into `state.operands`;
///  - when iter_args are present, result types are parsed into `state.types`
///    and used to type both the init operands and the iter_args block
///    arguments;
///  - `blockArgs` holds the iter_args block arguments followed by the used
///    coordinate arguments.
static ParseResult

  // Parse "%iters, ... in %spaces, ..."
  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
      parser.parseOperandList(spaces))
    return failure();

  if (iterators.size() != spaces.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of sparse iterators and sparse spaces");

  if (failed(parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();

  // Coordinates follow the iter_args in the block argument list.
  blockArgs.append(coords);

  SmallVector<Type> iterSpaceTps;
  // parse ": sparse_tensor.iter_space -> ret"
  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
    return failure();
  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  // Derive each iterator's type from its iteration-space type.
  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
    if (!spaceTp)
      return parser.emitError(parser.getNameLoc(),
                              "expected sparse_tensor.iter_space type for "
                              "iteration space operands");
    it.type = spaceTp.getIteratorType();
  }

  if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
      return failure();

  // Resolves input operands.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             state.operands))
    return failure();

  if (hasIterArgs) {
    // Strip off trailing args that are used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    // Each iter_arg and its init operand take the type of the matching result.
    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
    }
  }
  return success();
}
2291
/// Parses the header of a `sparse_tensor.coiterate` loop:
///   (%space, ...) [at(%crd, _, ...)] [iter_args(%a = %init, ...)]
///     : (<iteration space types>) [-> <result types>]
/// The resolved iteration spaces are returned in `spacesVals` and appended to
/// `state.operands`, followed by the resolved init operands. When iter_args
/// are present, result types are parsed into `state.types`. `blockArgs` holds
/// the iter_args block arguments followed by the used coordinate arguments.
static ParseResult
    SmallVectorImpl<Value> &spacesVals,

  // Parse "(%spaces, ...)"
    return failure();

  if (failed(parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();
  // Coordinates follow the iter_args in the block argument list.
  blockArgs.append(coords);

  SmallVector<Type> iterSpaceTps;
  // parse ": (sparse_tensor.iter_space, ...) -> ret"
  if (parser.parseColon() || parser.parseLParen() ||
      parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
    return failure();

  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
      return failure();

  // Resolves input sparse iteration spaces.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             spacesVals))
    return failure();
  state.operands.append(spacesVals);

  if (hasIterArgs) {
    // Strip off trailing args that used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    // Each iter_arg and its init operand take the type of the matching result.
    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
    }
  }
  return success();
}
2353
2354LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2355 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2356 DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2357 SmallVectorImpl<mlir::Type> &ret) {
2358
2359 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2360 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2361 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2362 adaptor.getHiLvl()));
2363 return success();
2364}
2365
2366LogicalResult ExtractIterSpaceOp::verify() {
2367 if (getLoLvl() >= getHiLvl())
2368 return emitOpError("expected smaller level low than level high");
2369
2370 TypedValue<IteratorType> pIter = getParentIter();
2371 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2372 return emitOpError(
2373 "parent iterator should be specified iff level lower bound equals 0");
2374 }
2375
2376 if (pIter) {
2377 IterSpaceType spaceTp = getExtractedSpace().getType();
2378 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2379 return emitOpError(
2380 "mismatch in parent iterator encoding and iteration space encoding.");
2381
2382 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2383 return emitOpError("parent iterator should be used to extract an "
2384 "iteration space from a consecutive level.");
2385 }
2386
2387 return success();
2388}
2389
2390LogicalResult ExtractValOp::verify() {
2391 auto stt = getSparseTensorType(getTensor());
2392 auto itTp = getIterator().getType();
2393
2394 if (stt.getEncoding() != itTp.getEncoding())
2395 return emitOpError("mismatch in tensor encoding and iterator encoding.");
2396
2397 if (stt.getLvlRank() != itTp.getHiLvl())
2398 return emitOpError("must use last-level iterator to extract values. ");
2399
2400 return success();
2401}
2402
/// Canonicalization for `sparse_tensor.iterate`. It erases coordinate block
/// arguments that have no users and rebuilds the op's `crdUsedLvls` set so
/// that it contains only the levels whose coordinates are still used. The
/// pattern fails (no change) when every coordinate argument has users.
struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {

  LogicalResult matchAndRewrite(IterateOp iterateOp,
                                PatternRewriter &rewriter) const override {
    // Levels whose coordinates remain in use after the rewrite.
    I64BitSet newUsedLvls(0);
    // Block-argument indices of unused coordinates to erase.
    llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
    for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
      if (auto crd = iterateOp.getLvlCrd(i)) {
        if (crd->getUsers().empty())
          toRemove.set(crd->getArgNumber());
        else
          newUsedLvls.set(i);
      }
    }

    // All coordinates are used.
    if (toRemove.none())
      return failure();

    // In-place update: shrink the used-level set and drop the dead arguments.
    rewriter.startOpModification(iterateOp);
    iterateOp.setCrdUsedLvls(newUsedLvls);
    iterateOp.getBody()->eraseArguments(toRemove);
    rewriter.finalizeOpModification(iterateOp);
    return success();
  }
};
2430
2431void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2432 mlir::MLIRContext *context) {
2433 results.add<RemoveUnusedLvlCrds>(context);
2434}
2435
2436void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2437 Value iterSpace, ValueRange initArgs) {
2438 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2439 // All ones.
2440 I64BitSet set((1 << rank) - 1);
2441 return build(builder, odsState, iterSpace, initArgs, set);
2442}
2443
2444void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2445 Value iterSpace, ValueRange initArgs,
2446 I64BitSet crdUsedLvls) {
2447 OpBuilder::InsertionGuard guard(builder);
2448
2449 odsState.addOperands(iterSpace);
2450 odsState.addOperands(initArgs);
2451 odsState.getOrAddProperties<Properties>().crdUsedLvls =
2452 builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2453 Region *bodyRegion = odsState.addRegion();
2454 odsState.addTypes(initArgs.getTypes());
2455 Block *bodyBlock = builder.createBlock(bodyRegion);
2456
2457 // Starts with a list of user-provided loop arguments.
2458 for (Value v : initArgs)
2459 bodyBlock->addArgument(v.getType(), v.getLoc());
2460
2461 // Follows by a list of used coordinates.
2462 for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2463 bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2464
2465 // Ends with sparse iterator
2466 bodyBlock->addArgument(
2467 llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2468 odsState.location);
2469}
2470
2471ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2472 OpAsmParser::Argument iterator;
2473 OpAsmParser::UnresolvedOperand iterSpace;
2474
2475 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2476 if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2477 return failure();
2478 if (iters.size() != 1)
2479 return parser.emitError(parser.getNameLoc(),
2480 "expected only one iterator/iteration space");
2481
2482 iterArgs.append(iters);
2483 Region *body = result.addRegion();
2484 if (parser.parseRegion(*body, iterArgs))
2485 return failure();
2486
2487 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2488
2489 // Parse the optional attribute list.
2490 if (parser.parseOptionalAttrDict(result.attributes))
2491 return failure();
2492
2493 return success();
2494}
2495
/// Prints the initialization list in the form of
/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
/// where 'inner' values are assumed to be region arguments and 'outer' values
/// are regular SSA values. Prints nothing at all (not even the prefix) when
/// `initializers` is empty. `blocksArgs` and `initializers` must have the
/// same length.
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  p << prefix << '(';
  // Pair each region argument with its initializer: "%inner = %outer".
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
  });
  p << ")";
}
2515
2516template <typename SparseLoopOp>
2517static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2518 if (op.getInitArgs().size() != op.getNumResults()) {
2519 return op.emitOpError(
2520 "mismatch in number of loop-carried values and defined values");
2521 }
2522 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2523 return op.emitOpError("required out-of-bound coordinates");
2524
2525 return success();
2526}
2527
2528LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2529LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2530
2531void IterateOp::print(OpAsmPrinter &p) {
2532 p << " " << getIterator() << " in " << getIterSpace();
2533 if (!getCrdUsedLvls().empty()) {
2534 p << " at(";
2535 printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2536 p << ")";
2537 }
2538 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2539
2540 p << " : " << getIterSpace().getType() << " ";
2541 if (!getInitArgs().empty())
2542 p.printArrowTypeList(getInitArgs().getTypes());
2543
2544 p << " ";
2545 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2546 /*printBlockTerminators=*/!getInitArgs().empty());
2547}
2548
2549LogicalResult IterateOp::verifyRegions() {
2550 if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2551 return emitOpError("mismatch in iterator and iteration space type");
2552 if (getNumRegionIterArgs() != getNumResults())
2553 return emitOpError(
2554 "mismatch in number of basic block args and defined values");
2555
2556 auto initArgs = getInitArgs();
2557 auto iterArgs = getRegionIterArgs();
2558 auto yieldVals = getYieldedValues();
2559 auto opResults = getResults();
2560 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2561 opResults.size()})) {
2562 return emitOpError() << "number mismatch between iter args and results.";
2563 }
2564
2565 for (auto [i, init, iter, yield, ret] :
2566 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2567 if (init.getType() != ret.getType())
2568 return emitOpError() << "types mismatch between " << i
2569 << "th iter operand and defined value";
2570 if (iter.getType() != ret.getType())
2571 return emitOpError() << "types mismatch between " << i
2572 << "th iter region arg and defined value";
2573 if (yield.getType() != ret.getType())
2574 return emitOpError() << "types mismatch between " << i
2575 << "th yield value and defined value";
2576 }
2577
2578 return success();
2579}
2580
2581/// OpInterfaces' methods implemented by IterateOp.
2582SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2583
2584MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2585 return getInitArgsMutable();
2586}
2587
2588Block::BlockArgListType IterateOp::getRegionIterArgs() {
2589 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2590}
2591
2592std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2593 return cast<sparse_tensor::YieldOp>(
2594 getRegion().getBlocks().front().getTerminator())
2595 .getResultsMutable();
2596}
2597
2598std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2599
2600OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2601 return getInitArgs();
2602}
2603
2604void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2605 SmallVectorImpl<RegionSuccessor> &regions) {
2606 // Both the operation itself and the region may be branching into the body
2607 // or back into the operation itself.
2608 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2609 // It is possible for loop not to enter the body.
2610 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2611}
2612
2613void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2614 ValueRange iterSpaces, ValueRange initArgs,
2615 unsigned numCases) {
2616 unsigned rank =
2617 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2618 // All ones.
2619 I64BitSet set((1 << rank) - 1);
2620 // Generates all-zero case bits (they only serve as placeholders), which are
2621 // supposed to be overriden later. We need to preallocate all the regions as
2622 // mlir::Region cannot be dynamically added later after the operation is
2623 // created.
2624 SmallVector<int64_t> caseBits(numCases, 0);
2625 ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2626 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2627 initArgs, set, cases,
2628 /*caseRegionsCount=*/numCases);
2629}
2630
/// Parses a `sparse_tensor.coiterate` operation: the loop header followed by
/// zero or more `case %it, _, ... { region }` clauses. For each case, the set
/// of iteration spaces it defines an iterator for is recorded in the `cases`
/// array attribute. Each case also gets its own region, whose entry arguments
/// are the shared header block arguments followed by that case's iterators.
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {

  SmallVector<Value> spaces;
  // The block argument list of each region is arranged in the order of
  // ([loop iteration args], [used coordinate list], [sparse iterator list]),
  // where the iterators are only those defined by the region's case.
  SmallVector<OpAsmParser::Argument> blockArgs;
  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
    return failure();

  // Operand segments: the iteration spaces, then one init per result type.
  result.addAttribute("operandSegmentSizes",
                          {static_cast<int32_t>(spaces.size()),
                           static_cast<int32_t>(result.types.size())}));

  SmallVector<Attribute> cases;
  while (succeeded(parser.parseOptionalKeyword("case"))) {
    // Parse one region per case.
    I64BitSet definedItSet;
    SmallVector<OpAsmParser::Argument> definedIts;
    if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
                                 spaces.size(), OpAsmParser::Delimiter::None))
      return failure();

    cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));

    for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
      // Resolve the iterator type based on the iteration space type.
      auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
      definedIts[i].type = spaceTp.getIteratorType();
    }
    // Shared header arguments come first, then this case's iterators.
    definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, definedIts))
      return failure();

    CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  }

  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2677
2678void CoIterateOp::print(OpAsmPrinter &p) {
2679 p << " (";
2680 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2681 p << ")";
2682
2683 if (!getCrdUsedLvls().empty()) {
2684 p << " at(";
2685 printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2686 p << ")";
2687 }
2688
2689 printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2690
2691 p << " : (" << getIterSpaces().getTypes() << ")";
2692 if (!getInitArgs().empty())
2693 p.printArrowTypeList(getInitArgs().getTypes());
2694
2695 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2696 p.printNewline();
2697 p << "case ";
2698 printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2699 getRegionDefinedSpace(idx));
2700 p << " ";
2701 p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2702 /*printBlockTerminators=*/!getInitArgs().empty());
2703 }
2704}
2705
2706ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2707 return cast<sparse_tensor::YieldOp>(
2708 getRegion(regionIdx).getBlocks().front().getTerminator())
2709 .getResults();
2710}
2711
2712LogicalResult CoIterateOp::verifyRegions() {
2713 for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2714 if (getNumRegionIterArgs() != getNumResults())
2715 return emitOpError(
2716 "mismatch in number of basic block args and defined values");
2717
2718 auto initArgs = getInitArgs();
2719 auto iterArgs = getRegionIterArgs(r);
2720 auto yieldVals = getYieldedValues(r);
2721 auto opResults = getResults();
2722 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2723 opResults.size()})) {
2724 return emitOpError()
2725 << "number mismatch between iter args and results on " << r
2726 << "th region";
2727 }
2728
2729 for (auto [i, init, iter, yield, ret] :
2730 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2731 if (init.getType() != ret.getType())
2732 return emitOpError()
2733 << "types mismatch between " << i
2734 << "th iter operand and defined value on " << r << "th region";
2735 if (iter.getType() != ret.getType())
2736 return emitOpError() << "types mismatch between " << i
2737 << "th iter region arg and defined value on " << r
2738 << "th region";
2739 if (yield.getType() != ret.getType())
2740 return emitOpError()
2741 << "types mismatch between " << i
2742 << "th yield value and defined value on " << r << "th region";
2743 }
2744 }
2745
2746 auto cases = getRegionDefinedSpaces();
2747 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2748 if (set.size() != getNumRegions())
2749 return emitOpError("contains duplicated cases.");
2750
2751 return success();
2752}
2753
2754SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2755 SmallVector<Region *> ret;
2756 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2757 for (Region &r : getCaseRegions())
2758 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2759 ret.push_back(&r);
2760
2761 return ret;
2762}
2763
2764//===----------------------------------------------------------------------===//
2765// Sparse Tensor Dialect Setups.
2766//===----------------------------------------------------------------------===//
2767
2768/// Materialize a single constant operation from a given attribute value with
2769/// the desired resultant type.
2770Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2771 Attribute value, Type type,
2772 Location loc) {
2773 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2774 return op;
2775 return nullptr;
2776}
2777
/// Registers everything the sparse_tensor dialect provides with the context.
/// The attribute, type and operation lists are expanded from TableGen-generated
/// `.cpp.inc` files selected by the GET_*_LIST macros.
void SparseTensorDialect::initialize() {
  // Attributes generated from SparseTensorAttrDefs.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  // Types generated from SparseTensorTypes.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  // Operations generated from SparseTensorOps.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Declares (without defining) that the listed ops implement
  // BufferizableOpInterface; the implementations are not in this file and
  // are presumably attached as external models elsewhere — NOTE(review):
  // confirm the registering location when adding ops to this list.
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2796
2797#define GET_OP_CLASSES
2798#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2799
2800#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:565
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:67
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:85
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
BlockArgListType getArguments()
Definition Block.h:87
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
IndexType getIndexType()
Definition Builders.cpp:51
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDiml mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Definition Enums.h:431
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition Enums.h:402
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
Definition Enums.h:432
bool isOrderedLT(LevelType lt)
Definition Enums.h:425
std::string toMLIRString(LevelType lt)
Definition Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
Definition Enums.h:421
uint64_t getN(LevelType lt)
Definition Enums.h:442
unsigned FieldIndex
The type of field indices.
llvm::hash_code hash_value(LevelType lt)
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
Definition Enums.h:413
uint64_t getM(LevelType lt)
Definition Enums.h:443
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
Definition Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Definition Enums.h:424
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition Enums.h:299