MLIR 23.0.0git
SparseTensorDialect.cpp
Go to the documentation of this file.
1//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
17
22#include "mlir/IR/Builders.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32
33// Forward declarations, following custom print/parsing methods are referenced
34// by the generated code for SparseTensorTypes.td.
35static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
40
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
43
44using namespace mlir;
45using namespace mlir::sparse_tensor;
46
47// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
48// well.
namespace mlir::sparse_tensor {
// Hash a LevelType through its underlying integer encoding, so that
// attributes containing LevelTypes (e.g. SparseTensorEncodingAttr) can be
// uniqued/hashed by MLIR's storage infrastructure.
static llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor
54
55//===----------------------------------------------------------------------===//
56// Local Convenience Methods.
57//===----------------------------------------------------------------------===//
58
// Returns true iff `bitWidth` is an admissible pos/crd bitwidth: the
// default 0 (meaning the native `index` type) or one of 8/16/32/64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
71
73getSparseFieldShape(const SparseTensorEncodingAttr enc,
74 std::optional<ArrayRef<int64_t>> dimShape) {
75 assert(enc);
76 // With only encoding, we can not determine the static shape for leading
77 // batch levels, we therefore return a dynamic shape memref instead.
78 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
79 if (dimShape.has_value()) {
80 // If the actual tensor shape is provided, we can then refine the leading
81 // batch dimension.
82 SmallVector<int64_t> lvlShape =
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
86 }
87 // Another dynamic dimension to store the sparse level.
88 memrefShape.push_back(ShapedType::kDynamic);
89 return memrefShape;
90}
91
92//===----------------------------------------------------------------------===//
93// SparseTensorDialect StorageLayout.
94//===----------------------------------------------------------------------===//
95
96static constexpr Level kInvalidLevel = -1u;
97static constexpr Level kInvalidFieldIndex = -1u;
98static constexpr FieldIndex kDataFieldStartingIdx = 0;
99
102 LevelType)>
103 callback) const {
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
106 SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
108
109 ArrayRef cooSegsRef = cooSegs;
110 // Per-level storage.
111 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
112 const auto lt = lvlTypes[l];
113 if (isWithPosLT(lt)) {
114 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
115 return;
116 }
117 if (isWithCrdLT(lt)) {
118 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
119 return;
120 }
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
123 // AoS COO, all singletons are fused into one memrefs. Skips the entire
124 // COO segement.
125 l = cooSegsRef.front().lvlRange.second;
126 } else {
127 // SoA COO, each singleton level has one memref.
128 l++;
129 }
130 // Expire handled COO segment.
131 cooSegsRef = cooSegsRef.drop_front();
132 } else {
133 // Non COO levels.
134 l++;
135 }
136 }
137 // The values array.
138 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
140 return;
141 // Put metadata at the end.
142 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
144 return;
145}
146
150 LevelType)>
151 callback) {
152 assert(stt.hasEncoding());
153
154 SmallVector<int64_t> memrefShape =
156
157 const Type specType = StorageSpecifierType::get(stt.getEncoding());
158 // memref<[batch] x ? x pos> positions
159 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
160 // memref<[batch] x ? x crd> coordinates
161 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
162 // memref<[batch] x ? x eltType> values
163 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
164
165 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
166 callback](FieldIndex fieldIdx,
167 SparseTensorFieldKind fieldKind,
168 Level lvl, LevelType lt) -> bool {
169 switch (fieldKind) {
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
178 };
179 llvm_unreachable("unrecognized field kind");
180 });
181}
182
184 unsigned numFields = 0;
186 LevelType) -> bool {
187 numFields++;
188 return true;
189 });
190 return numFields;
191}
192
194 unsigned numFields = 0; // one value memref
196 LevelType) -> bool {
197 if (fidx >= kDataFieldStartingIdx)
198 numFields++;
199 return true;
200 });
201 numFields -= 1; // the last field is StorageSpecifier
202 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
203 return numFields;
204}
205
206std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl) const {
210 unsigned stride = 1;
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
216 lvl = cooStart;
217 stride = lvlRank - cooStart;
218 }
219 }
220 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
221 SparseTensorFieldKind fKind, Level fLvl,
222 LevelType lt) -> bool {
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
224 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
225 fieldIdx = fIdx;
226 // Returns false to break the iteration.
227 return false;
228 }
229 return true;
230 });
231 assert(fieldIdx != kInvalidFieldIndex);
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
233}
234
235//===----------------------------------------------------------------------===//
236// SparseTensorDialect Attribute Methods.
237//===----------------------------------------------------------------------===//
238
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(static_cast<uint64_t>(v));
242}
243
// Returns the slice offset when static, or `nullopt` when dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}
247
// Returns the slice stride when static, or `nullopt` when dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}
251
// Returns the slice size when static, or `nullopt` when dynamic.
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}
255
// True iff all three slice components (offset, stride, size) are dynamic.
bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}
260
261std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ? "?" : std::to_string(v);
263}
264
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
266 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
267 os << '(';
268 os << getStaticString(getOffset());
269 os << ", ";
270 os << getStaticString(getSize());
271 os << ", ";
272 os << getStaticString(getStride());
273 os << ')';
274}
275
// AsmPrinter overload: forwards to the raw-stream printer above.
void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}
279
281 AsmParser &parser) {
282 auto parseResult = parser.parseOptionalInteger(result);
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() && result < 0) {
285 parser.emitError(
286 parser.getCurrentLocation(),
287 "expect positive value or ? for slice offset/size/stride");
288 return failure();
289 }
290 return parseResult.value();
291 }
292
293 // Else, and '?' which represented dynamic slice
294 result = SparseTensorDimSliceAttr::kDynamic;
295 return parser.parseQuestion();
296}
297
// Parses the "(offset, size, stride)" slice syntax, where each component is
// either a non-negative integer or `?` (dynamic). Returns a null attribute
// on syntax error.
Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  // `getChecked` runs the attribute verifier and emits a diagnostic on
  // failure (e.g. a negative static value).
  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}
313
// Verifies a dim-slice triple: offset may be zero, but size and stride must
// be strictly positive; the dynamic marker is always accepted for any of
// the three components.
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
325
// Returns a copy of this encoding with `dimToLvl` replaced.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  // Deliberately pass an empty lvlToDim: the old inverse no longer matches
  // the new dimToLvl and must be re-inferred downstream.
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}
333
// Adopts the dimToLvl map of `enc` (or the empty/identity map when `enc` is
// null).
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}
338
// Drops the dimToLvl map (empty map denotes the identity mapping).
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}
342
// Returns a copy of this encoding with the given pos/crd bitwidths; all
// other members are preserved.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}
351
// Resets both bitwidths to 0, i.e. the default `index` type.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}
355
// Returns a copy of this encoding with the explicit value replaced.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}
363
// Drops the explicit value (a null attribute means "not set").
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}
367
// Returns a copy of this encoding with the implicit value replaced.
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}
375
// Drops the implicit value (a null attribute means "not set").
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}
379
// Returns a copy of this encoding with the given per-dimension slices.
SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}
386
// Drops all per-dimension slices (empty array means "not a slice").
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
390
// Number of leading batch levels. Scans from the back for the last batch
// level; since the verifier enforces that batch levels are leading, the
// distance from that position to rend() is the batch rank (0 when none).
uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}
396
// True iff every level is dense; a null encoding denotes a dense tensor.
bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}
400
// True iff every level is ordered; a null encoding is trivially ordered.
bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}
404
405Type SparseTensorEncodingAttr::getCrdElemType() const {
406 if (!getImpl())
407 return nullptr;
408 if (getCrdWidth())
409 return IntegerType::get(getContext(), getCrdWidth());
410 return IndexType::get(getContext());
411}
412
413Type SparseTensorEncodingAttr::getPosElemType() const {
414 if (!getImpl())
415 return nullptr;
416 if (getPosWidth())
417 return IntegerType::get(getContext(), getPosWidth());
418 return IndexType::get(getContext());
419}
420
// Builds the coordinates memref type; leading batch dims are refined from
// `dimShape` when provided, otherwise they stay dynamic.
MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}
426
// Builds the positions memref type; leading batch dims are refined from
// `dimShape` when provided, otherwise they stay dynamic.
MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}
432
// True iff dimToLvl is the identity; a null encoding or missing map counts
// as identity.
bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}
436
// True iff dimToLvl is a permutation; a null encoding or missing map counts
// as (identity, hence) permutation.
bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}
440
// Dimension rank is the domain size of dimToLvl; without an explicit map
// (identity) it equals the level rank.
Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}
446
// Level rank is defined by the length of the level-types array (the
// source of truth, also used by the verifier).
Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}
451
// Level type at level `l`. A null encoding (dense tensor) reports the
// default Batch level type for every level.
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}
458
// True iff this encoding carries per-dimension slice information.
bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}
463
// Returns the slice attribute for dimension `dim`; only valid on slices.
SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}
471
// Static slice offset for dimension `dim`, or `nullopt` when dynamic.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}
476
// Static slice stride for dimension `dim`, or `nullopt` when dynamic.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}
481
// Static slice offset for level `lvl`, translated to its dimension.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}
486
// Static slice stride for level `lvl`, translated to its dimension.
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}
491
492SmallVector<int64_t>
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir) const {
495 if (isIdentity())
496 return SmallVector<int64_t>(srcShape);
497
498 SmallVector<int64_t> ret;
499 unsigned rank =
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
501 ret.reserve(rank);
502
503 if (isPermutation()) {
504 for (unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
506 : toLvl(*this, r);
507 ret.push_back(srcShape[trans]);
508 }
509 return ret;
510 }
511
512 // Handle non-permutation maps.
513 AffineMap transMap =
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
515
516 // Check if transMap is valid. There are cases where the lvlToDim map is
517 // uninitialized due to the format used, e.g. ELL. This is visible as
518 // inferring lvlToDim (see inferLvlToDim function below) may return an
519 // uninitialized affine map. Fallback to dynamic shapes.
520 if (!transMap) {
521 ret.resize(rank, ShapedType::kDynamic);
522 return ret;
523 }
524
525 SmallVector<AffineExpr> dimRep;
526 dimRep.reserve(srcShape.size());
527 for (int64_t sz : srcShape) {
528 if (ShapedType::isStatic(sz)) {
529 // Push back the max coordinate for the given dimension/level size.
530 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
531 } else {
532 // A dynamic size, use a AffineDimExpr to symbolize the value.
533 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
534 }
535 };
536
537 // The number of symbols information is included inside the `dimToLvl` map
538 // during parsing. Here, we're extracting it to be used when simplifying the
539 // affine expression.
540 unsigned numSymbols = getDimToLvl().getNumSymbols();
541
542 for (AffineExpr exp : transMap.getResults()) {
543 // Do constant propagation on the affine map.
544 AffineExpr evalExp = simplifyAffineExpr(exp.replaceDims(dimRep),
545 srcShape.size(), numSymbols);
546 // use llvm namespace here to avoid ambiguity
547 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
548 ret.push_back(c.getValue() + 1);
549 } else {
550 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
551 mod && mod.getKind() == AffineExprKind::Mod) {
552 // We can still infer a static bound for expressions in form
553 // "d % constant" since d % constant \in [0, constant).
554 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
555 ret.push_back(bound.getValue());
556 continue;
557 }
558 }
559 ret.push_back(ShapedType::kDynamic);
560 }
561 }
562 assert(ret.size() == rank);
563 return ret;
564}
565
567SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
568 ValueRange crds,
569 CrdTransDirectionKind dir) const {
570 if (!getImpl())
571 return crds;
572
573 SmallVector<Type> retType(
574 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
575 builder.getIndexType());
576 auto transOp =
577 CrdTranslateOp::create(builder, loc, retType, crds, dir, *this);
578 return transOp.getOutCrds();
579}
580
// Parses the `#sparse_tensor.encoding<{ ... }>` attribute body: a brace-
// enclosed, comma-separated list of `key = value` entries with admissible
// keys "map", "posWidth", "crdWidth", "explicitVal", "implicitVal".
// Returns a null attribute and emits a diagnostic on error.
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after keys
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      // Only numeric (float/int/complex) attributes are admissible.
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      // Only numeric (float/int/complex) attributes are admissible.
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  // Infer lvlToDim when the map did not provide one (or provided an empty
  // map); `getChecked` then runs the verifier with diagnostics.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}
727
728void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
729 auto map = static_cast<AffineMap>(getDimToLvl());
730 // Empty affine map indicates identity map
731 if (!map)
732 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
733 printer << "<{ map = ";
734 printSymbols(map, printer);
735 printer << '(';
736 printDimensions(map, printer, getDimSlices());
737 printer << ") -> (";
738 printLevels(map, printer, getLvlTypes());
739 printer << ')';
740 // Print remaining members only for non-default values.
741 if (getPosWidth())
742 printer << ", posWidth = " << getPosWidth();
743 if (getCrdWidth())
744 printer << ", crdWidth = " << getCrdWidth();
745 if (getExplicitVal()) {
746 printer << ", explicitVal = " << getExplicitVal();
747 }
748 if (getImplicitVal())
749 printer << ", implicitVal = " << getImplicitVal();
750 printer << " }>";
751}
752
753void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
754 AsmPrinter &printer) const {
755 if (map.getNumSymbols() == 0)
756 return;
757 printer << '[';
758 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
759 printer << 's' << i << ", ";
760 if (map.getNumSymbols() >= 1)
761 printer << 's' << map.getNumSymbols() - 1;
762 printer << ']';
763}
764
765void SparseTensorEncodingAttr::printDimensions(
766 AffineMap &map, AsmPrinter &printer,
767 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
768 if (!dimSlices.empty()) {
769 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
770 printer << 'd' << i << " : " << dimSlices[i] << ", ";
771 if (map.getNumDims() >= 1) {
772 printer << 'd' << map.getNumDims() - 1 << " : "
773 << dimSlices[map.getNumDims() - 1];
774 }
775 } else {
776 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
777 printer << 'd' << i << ", ";
778 if (map.getNumDims() >= 1)
779 printer << 'd' << map.getNumDims() - 1;
780 }
781}
782
783void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
784 ArrayRef<LevelType> lvlTypes) const {
785 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
786 map.getResult(i).print(printer.getStream());
787 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
788 }
789 if (map.getNumResults() >= 1) {
790 auto lastIndex = map.getNumResults() - 1;
791 map.getResult(lastIndex).print(printer.getStream());
792 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
793 }
794}
795
796LogicalResult SparseTensorEncodingAttr::verify(
797 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
798 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
799 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
800 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
801 if (!acceptBitWidth(posWidth))
802 return emitError() << "unexpected position bitwidth: " << posWidth;
803 if (!acceptBitWidth(crdWidth))
804 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
805
806 // Verify every COO segment.
807 auto *it = llvm::find_if(lvlTypes, isSingletonLT);
808 while (it != lvlTypes.end()) {
809 if (it == lvlTypes.begin() ||
811 return emitError() << "expected compressed or loose_compressed level "
812 "before singleton level";
813
814 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
815 if (!std::all_of(it, curCOOEnd, isSingletonLT))
816 return emitError() << "expected all singleton lvlTypes "
817 "following a singleton level";
818 // We can potentially support mixed SoA/AoS singleton levels.
819 if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
820 return it->isa<LevelPropNonDefault::SoA>() ==
822 })) {
823 return emitError() << "expected all singleton lvlTypes stored in the "
824 "same memory layout (SoA vs AoS).";
825 }
826 it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
827 }
828
829 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
830 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
831 return emitError() << "Batch lvlType can only be leading levels.";
832
833 // SoA property can only be applied on singleton level.
834 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
835 return lt.isa<LevelPropNonDefault::SoA>();
836 });
837 if (llvm::any_of(soaLvls, [](LevelType lt) {
838 return !lt.isa<LevelFormat::Singleton>();
839 })) {
840 return emitError() << "SoA is only applicable to singleton lvlTypes.";
841 }
842
843 // Dense levels cannot follow a non-unique level. The iteration model for
844 // dense levels requires exactly one parent position to linearize into a
845 // contiguous range, but a non-unique parent provides two cursor values
846 // (segment start and end), which the dense level cannot handle.
847 for (auto [i, lt] : llvm::drop_begin(llvm::enumerate(lvlTypes))) {
848 if (isDenseLT(lt) && !isUniqueLT(lvlTypes[i - 1]))
849 return emitError() << "dense level cannot follow a non-unique level";
850 }
851
852 // TODO: audit formats that actually are supported by backend.
853 if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
854 it != std::end(lvlTypes)) {
855 if (it != lvlTypes.end() - 1)
856 return emitError() << "expected n_out_of_m to be the last level type";
857 if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
858 return emitError() << "expected all dense lvlTypes "
859 "before a n_out_of_m level";
860 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
861 if (!isBlockSparsity(dimToLvl)) {
862 return emitError()
863 << "expected 1xm block structure for n_out_of_m level";
864 }
865 auto sizes = getBlockSize(dimToLvl);
866 unsigned coefficient = 0;
867 for (const auto &elem : sizes) {
868 if (elem != 0) {
869 if (elem != coefficient && coefficient != 0) {
870 return emitError() << "expected only one blocked level "
871 "with the same coefficients";
872 }
873 coefficient = elem;
874 }
875 }
876 if (coefficient != getM(*it)) {
877 return emitError() << "expected coeffiencts of Affine expressions "
878 "to be equal to m of n_out_of_m level";
879 }
880 }
881 }
882 // Before we can check that the level-rank is consistent/coherent
883 // across all fields, we need to define it. The source-of-truth for
884 // the `getLvlRank` method is the length of the level-types array,
885 // since it must always be provided and have full rank; therefore we
886 // use that same source-of-truth here.
887 const Level lvlRank = lvlTypes.size();
888 if (lvlRank == 0)
889 return emitError() << "expected a non-empty array for lvlTypes";
890 // We save `dimRank` here because we'll also need it to verify `dimSlices`.
891 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
892 if (dimToLvl) {
893 if (dimToLvl.getNumResults() != lvlRank)
894 return emitError()
895 << "level-rank mismatch between dimToLvl and lvlTypes: "
896 << dimToLvl.getNumResults() << " != " << lvlRank;
897 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
898 // Symbols can't be inferred but are acceptable.
899 if (!inferRes && dimToLvl.getNumSymbols() == 0)
900 return emitError() << "failed to infer lvlToDim from dimToLvl";
901 if (lvlToDim && (inferRes != lvlToDim))
902 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
903 if (dimRank > lvlRank)
904 return emitError() << "unexpected dimToLvl mapping from " << dimRank
905 << " to " << lvlRank;
906 }
907 if (!dimSlices.empty()) {
908 if (dimSlices.size() != dimRank)
909 return emitError()
910 << "dimension-rank mismatch between dimSlices and dimToLvl: "
911 << dimSlices.size() << " != " << dimRank;
912 // Compiler support for `dimSlices` currently requires that the two
913 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
914 if (dimRank != lvlRank)
915 return emitError()
916 << "dimSlices expected dimension-rank to match level-rank: "
917 << dimRank << " != " << lvlRank;
918 }
919 return success();
920}
921
// Verifies this encoding against the tensor it is attached to: structural
// integrity first, then agreement of dimension-rank and of the
// explicit/implicit value types with the tensor's element type.
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  // NOTE(review): the dyn_cast<TypedAttr> below assumes the explicit value
  // is a typed attribute (the parser admits only float/int/complex attrs);
  // a non-typed attribute would yield a null cast — confirm upstream.
  if (auto expVal = getExplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}
969
970Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
971 SmallVector<COOSegment> coo = getCOOSegments();
972 assert(coo.size() == 1 || coo.empty());
973 if (!coo.empty() && coo.front().isAoS()) {
974 return coo.front().lvlRange.first;
975 }
976 return getLvlRank();
977}
978
979SmallVector<COOSegment>
980mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
981 SmallVector<COOSegment> ret;
982 if (getLvlRank() <= 1)
983 return ret;
984
985 ArrayRef<LevelType> lts = getLvlTypes();
986 Level l = 0;
987 while (l < getLvlRank()) {
988 auto lt = lts[l];
990 auto cur = lts.begin() + l;
991 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
992 return !lt.isa<LevelFormat::Singleton>();
993 });
994 unsigned cooLen = std::distance(cur, end);
995 if (cooLen > 1) {
996 // To support mixed SoA/AoS COO, we should break the segment when the
997 // storage scheme changes, for now we faithfully assume that all
998 // consecutive singleton levels have the same storage format as verified
999 // STEA.
1000 ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
1001 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
1002 }
1003 l += cooLen;
1004 } else {
1005 l++;
1006 }
1007 }
1008 return ret;
1009}
1010
1011//===----------------------------------------------------------------------===//
1012// SparseTensorType Methods.
1013//===----------------------------------------------------------------------===//
1014
1016 bool isUnique) const {
1017 if (!hasEncoding())
1018 return false;
1019 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
1020 return false;
1021 for (Level l = startLvl + 1; l < lvlRank; ++l)
1022 if (!isSingletonLvl(l))
1023 return false;
1024 // If isUnique is true, then make sure that the last level is unique,
1025 // that is, when lvlRank == 1, the only compressed level is unique,
1026 // and when lvlRank > 1, the last singleton is unique.
1027 return !isUnique || isUniqueLvl(lvlRank - 1);
1028}
1029
1030RankedTensorType
1032 SmallVector<LevelType> lvlTypes;
1033 lvlTypes.reserve(lvlRank);
1034 // A non-unique compressed level at beginning (unless this is
1035 // also the last level, then it is unique).
1036 lvlTypes.push_back(
1037 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
1038 if (lvlRank > 1) {
1039 // Followed by n-2 non-unique singleton levels.
1040 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1041 *buildLevelType(LevelFormat::Singleton, ordered, false));
1042 // Ends by a unique singleton level.
1043 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1044 }
1045 auto enc = SparseTensorEncodingAttr::get(
1046 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1048 return RankedTensorType::get(getDimShape(), getElementType(), enc);
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// Convenience Methods.
1053//===----------------------------------------------------------------------===//
1054
1055SparseTensorEncodingAttr
1057 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1058 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1059 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1060 return mdtp.getEncoding();
1061 return nullptr;
1062}
1063
1065 MLIRContext *context) {
1066 auto map = static_cast<AffineMap>(dimToLvl);
1067 AffineMap lvlToDim;
1068 // Return an empty lvlToDim when inference is not successful.
1069 if (!map || map.getNumSymbols() != 0) {
1070 lvlToDim = AffineMap();
1071 } else if (map.isPermutation()) {
1072 lvlToDim = inversePermutation(map);
1073 } else if (isBlockSparsity(map)) {
1074 lvlToDim = inverseBlockSparsity(map, context);
1075 }
1076 return lvlToDim;
1077}
1078
1080 MLIRContext *context) {
1081 SmallVector<AffineExpr> lvlExprs;
1082 auto numLvls = dimToLvl.getNumResults();
1083 lvlExprs.reserve(numLvls);
1084 // lvlExprComponents stores information of the floordiv and mod operations
1085 // applied to the same dimension, so as to build the lvlToDim map.
1086 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1087 for (unsigned i = 0, n = numLvls; i < n; i++) {
1088 auto result = dimToLvl.getResult(i);
1089 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1090 if (result.getKind() == AffineExprKind::FloorDiv) {
1091 // Position of the dimension in dimToLvl.
1092 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1093 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1094 "expected only one floordiv for each dimension");
1095 SmallVector<AffineExpr, 3> components;
1096 // Level variable for floordiv.
1097 components.push_back(getAffineDimExpr(i, context));
1098 // Multiplier.
1099 components.push_back(binOp.getRHS());
1100 // Map key is the position of the dimension.
1101 lvlExprComponents[pos] = components;
1102 } else if (result.getKind() == AffineExprKind::Mod) {
1103 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1104 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1105 "expected floordiv before mod");
1106 // Add level variable for mod to the same vector
1107 // of the corresponding floordiv.
1108 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1109 } else {
1110 assert(false && "expected floordiv or mod");
1111 }
1112 } else {
1113 lvlExprs.push_back(getAffineDimExpr(i, context));
1114 }
1115 }
1116 // Build lvlExprs from lvlExprComponents.
1117 // For example, for il = i floordiv 2 and ii = i mod 2, the components
1118 // would be [il, 2, ii]. It could be used to build the AffineExpr
1119 // i = il * 2 + ii in lvlToDim.
1120 for (auto &components : lvlExprComponents) {
1121 assert(components.second.size() == 3 &&
1122 "expected 3 components to build lvlExprs");
1123 auto mulOp = getAffineBinaryOpExpr(
1124 AffineExprKind::Mul, components.second[0], components.second[1]);
1125 auto addOp =
1126 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1127 lvlExprs.push_back(addOp);
1128 }
1129 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1130}
1131
1133 assert(isBlockSparsity(dimToLvl) &&
1134 "expected dimToLvl to be block sparsity for calling getBlockSize");
1135 SmallVector<unsigned> blockSize;
1136 for (auto result : dimToLvl.getResults()) {
1137 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1138 if (result.getKind() == AffineExprKind::Mod) {
1139 blockSize.push_back(
1140 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1141 }
1142 } else {
1143 blockSize.push_back(0);
1144 }
1145 }
1146 return blockSize;
1147}
1148
1150 if (!dimToLvl)
1151 return false;
1152 std::map<unsigned, int64_t> coeffientMap;
1153 bool hasBlock = false;
1154 for (auto result : dimToLvl.getResults()) {
1155 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1156 // Check for "dim op const".
1157 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1158 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1159 if (!dimOp || !conOp || conOp.getValue() <= 0)
1160 return false;
1161 // Inspect "dim / const" or "dim % const".
1162 auto pos = dimOp.getPosition();
1163 if (binOp.getKind() == AffineExprKind::FloorDiv) {
1164 // Expect only one floordiv for each dimension.
1165 auto [it, inserted] = coeffientMap.try_emplace(pos);
1166 if (!inserted)
1167 return false;
1168 // Record coefficient of the floordiv.
1169 it->second = conOp.getValue();
1170 } else if (binOp.getKind() == AffineExprKind::Mod) {
1171 // Expect floordiv before mod.
1172 auto it = coeffientMap.find(pos);
1173 if (it == coeffientMap.end())
1174 return false;
1175 // Expect mod to have the same coefficient as floordiv.
1176 if (conOp.getValue() != it->second)
1177 return false;
1178 hasBlock = true;
1179 } else {
1180 return false;
1181 }
1182 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1183 auto pos = dimOp.getPosition();
1184 // Expect dim to be unset.
1185 if (!coeffientMap.try_emplace(pos, 0).second)
1186 return false;
1187 } else {
1188 return false;
1189 }
1190 }
1191 return hasBlock;
1192}
1193
1195 auto hasNonIdentityMap = [](Value v) {
1196 auto stt = tryGetSparseTensorType(v);
1197 return stt && !stt->isIdentity();
1198 };
1199
1200 return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1201 llvm::any_of(op->getResults(), hasNonIdentityMap);
1202}
1203
1204Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1205 if (enc) {
1206 assert(enc.isPermutation() && "Non permutation map not supported");
1207 if (const auto dimToLvl = enc.getDimToLvl())
1208 return dimToLvl.getDimPosition(l);
1209 }
1210 return l;
1211}
1212
1213Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1214 if (enc) {
1215 assert(enc.isPermutation() && "Non permutation map not supported");
1216 if (const auto lvlToDim = enc.getLvlToDim())
1217 return lvlToDim.getDimPosition(d);
1218 }
1219 return d;
1220}
1221
1222/// We normalized sparse tensor encoding attribute by always using
1223/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1224/// as other variants) lead to the same storage specifier type, and stripping
1225/// irrelevant fields that do not alter the sparse tensor memory layout.
1226static SparseTensorEncodingAttr
1227getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1229 for (auto lt : enc.getLvlTypes())
1230 lts.push_back(lt.stripStorageIrrelevantProperties());
1231
1232 return SparseTensorEncodingAttr::get(
1233 enc.getContext(), lts,
1234 AffineMap(), // dimToLvl (irrelevant to storage specifier)
1235 AffineMap(), // lvlToDim (irrelevant to storage specifier)
1236 // Always use `index` for memSize and lvlSize instead of reusing
1237 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
1238 // value for different bitwidth, it also avoids casting between index and
1239 // integer (returned by DimOp)
1240 0, 0,
1241 Attribute(), // explicitVal (irrelevant to storage specifier)
1242 Attribute(), // implicitVal (irrelevant to storage specifier)
1243 enc.getDimSlices());
1244}
1245
1246StorageSpecifierType
1247StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1248 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1249}
1250
1251StorageSpecifierType
1252StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1253 MLIRContext *ctx,
1254 SparseTensorEncodingAttr encoding) {
1255 return Base::getChecked(emitError, ctx,
1257}
1258
1259//===----------------------------------------------------------------------===//
1260// SparseTensorDialect Operations.
1261//===----------------------------------------------------------------------===//
1262
1263static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1264 return success(lvl < getSparseTensorType(tensor).getLvlRank());
1265}
1266
1267static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1268 const Type etp = getMemRefType(mem).getElementType();
1269 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1270}
1271
1272static LogicalResult verifySparsifierGetterSetter(
1273 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1275 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1276 return op->emitError(
1277 "redundant level argument for querying value memory size");
1278 }
1279
1280 const auto enc = md.getType().getEncoding();
1281 const Level lvlRank = enc.getLvlRank();
1282
1283 if (mdKind == StorageSpecifierKind::DimOffset ||
1284 mdKind == StorageSpecifierKind::DimStride)
1285 if (!enc.isSlice())
1286 return op->emitError("requested slice data on non-slice tensor");
1287
1288 if (mdKind != StorageSpecifierKind::ValMemSize) {
1289 if (!lvl)
1290 return op->emitError("missing level argument");
1291
1292 const Level l = lvl.value();
1293 if (l >= lvlRank)
1294 return op->emitError("requested level is out of bounds");
1295
1296 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1297 return op->emitError(
1298 "requested position memory size on a singleton level");
1299 }
1300 return success();
1301}
1302
1304 switch (kind) {
1306 return stt.getCrdType();
1308 return stt.getPosType();
1310 return stt.getElementType();
1312 return nullptr;
1313 }
1314 llvm_unreachable("Unrecognizable FieldKind");
1315}
1316
1317static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1318 SparseTensorType stt,
1319 RankedTensorType valTp,
1320 TypeRange lvlTps) {
1321 if (requiresStaticShape && !stt.hasStaticDimShape())
1322 return op->emitError("the sparse-tensor must have static shape");
1323 if (!stt.hasEncoding())
1324 return op->emitError("the sparse-tensor must have an encoding attribute");
1325
1326 // Verifies the trailing COO.
1327 Level cooStartLvl = stt.getAoSCOOStart();
1328 if (cooStartLvl < stt.getLvlRank()) {
1329 // We only supports trailing COO for now, must be the last input.
1330 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1331 // The coordinates should be in shape of <? x rank>
1332 unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1333 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1334 return op->emitError("input/output trailing COO level-ranks don't match");
1335 }
1336 }
1337
1338 // Verifies that all types match.
1339 StorageLayout layout(stt.getEncoding());
1340 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1341 return op->emitError("inconsistent number of fields between input/output");
1342
1343 unsigned idx = 0;
1344 bool misMatch = false;
1345 layout.foreachField([&idx, &misMatch, stt, valTp,
1346 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1347 Level lvl, LevelType lt) -> bool {
1349 return true;
1350
1351 Type inputTp = nullptr;
1352 if (fKind == SparseTensorFieldKind::ValMemRef) {
1353 inputTp = valTp;
1354 } else {
1355 assert(fid == idx && stt.getLvlType(lvl) == lt);
1356 inputTp = lvlTps[idx++];
1357 }
1358 // The input element type and expected element type should match.
1359 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1360 Type expElemTp = getFieldElemType(stt, fKind);
1361 if (inpElemTp != expElemTp) {
1362 misMatch = true;
1363 return false; // to terminate the iteration
1364 }
1365 return true;
1366 });
1367
1368 if (misMatch)
1369 return op->emitError("input/output element-types don't match");
1370 return success();
1371}
1372
1373LogicalResult AssembleOp::verify() {
1374 RankedTensorType valuesTp = getValues().getType();
1375 const auto lvlsTp = getLevels().getTypes();
1376 const auto resTp = getSparseTensorType(getResult());
1377 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1378}
1379
1380LogicalResult DisassembleOp::verify() {
1381 if (getOutValues().getType() != getRetValues().getType())
1382 return emitError("output values and return value type mismatch");
1383
1384 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1385 if (ot.getType() != rt.getType())
1386 return emitError("output levels and return levels type mismatch");
1387
1388 RankedTensorType valuesTp = getRetValues().getType();
1389 const auto lvlsTp = getRetLevels().getTypes();
1390 const auto srcTp = getSparseTensorType(getTensor());
1391 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1392}
1393
1394LogicalResult ConvertOp::verify() {
1395 RankedTensorType tp1 = getSource().getType();
1396 RankedTensorType tp2 = getDest().getType();
1397 if (tp1.getRank() != tp2.getRank())
1398 return emitError("unexpected conversion mismatch in rank");
1399 auto dstEnc =
1400 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1401 if (dstEnc && dstEnc.isSlice())
1402 return emitError("cannot convert to a sparse tensor slice");
1403
1404 auto shape1 = tp1.getShape();
1405 auto shape2 = tp2.getShape();
1406 // Accept size matches between the source and the destination type
1407 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1408 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1409 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1410 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1411 return emitError("unexpected conversion mismatch in dimension ") << d;
1412 return success();
1413}
1414
1415OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1416 if (getType() == getSource().getType())
1417 return getSource();
1418 return {};
1419}
1420
1421bool ConvertOp::needsExtraSort() {
1422 SparseTensorType srcStt = getSparseTensorType(getSource());
1423 SparseTensorType dstStt = getSparseTensorType(getDest());
1424
1425 // We do not need an extra sort when returning unordered sparse tensors or
1426 // dense tensor since dense tensor support random access.
1427 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1428 return false;
1429
1430 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1431 srcStt.hasSameDimToLvl(dstStt)) {
1432 return false;
1433 }
1434
1435 // Source and dest tensors are ordered in different ways. We only do direct
1436 // dense to sparse conversion when the dense input is defined by a sparse
1437 // constant. Note that we can theoretically always directly convert from dense
1438 // inputs by rotating dense loops but it leads to bad cache locality and hurt
1439 // performance.
1440 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1441 if (isa<SparseElementsAttr>(constOp.getValue()))
1442 return false;
1443
1444 return true;
1445}
1446
1447LogicalResult CrdTranslateOp::verify() {
1448 uint64_t inRank = getEncoder().getLvlRank();
1449 uint64_t outRank = getEncoder().getDimRank();
1450
1451 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1452 std::swap(inRank, outRank);
1453
1454 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1455 return emitError("Coordinate rank mismatch with encoding");
1456
1457 return success();
1458}
1459
1460LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1461 SmallVectorImpl<OpFoldResult> &results) {
1462 if (getEncoder().isIdentity()) {
1463 results.assign(getInCrds().begin(), getInCrds().end());
1464 return success();
1465 }
1466 if (getEncoder().isPermutation()) {
1467 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1468 ? getEncoder().getDimToLvl()
1469 : getEncoder().getLvlToDim();
1470 for (AffineExpr exp : perm.getResults())
1471 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1472 return success();
1473 }
1474
1475 // Fuse dim2lvl/lvl2dim pairs.
1476 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1477 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1478 return v.getDefiningOp() == def;
1479 });
1480 if (!sameDef)
1481 return failure();
1482
1483 bool oppositeDir = def.getDirection() != getDirection();
1484 bool sameOracle =
1485 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1486 bool sameCount = def.getNumResults() == getInCrds().size();
1487 if (!oppositeDir || !sameOracle || !sameCount)
1488 return failure();
1489
1490 // The definition produces the coordinates in the same order as the input
1491 // coordinates.
1492 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1493 [](auto valuePair) {
1494 auto [lhs, rhs] = valuePair;
1495 return lhs == rhs;
1496 });
1497
1498 if (!sameOrder)
1499 return failure();
1500 // l1 = dim2lvl (lvl2dim l0)
1501 // ==> l0
1502 results.append(def.getInCrds().begin(), def.getInCrds().end());
1503 return success();
1504}
1505
1506void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1507 int64_t index) {
1508 Value val = arith::ConstantIndexOp::create(builder, state.location, index);
1509 return build(builder, state, source, val);
1510}
1511
1512LogicalResult LvlOp::verify() {
1513 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1514 auto stt = getSparseTensorType(getSource());
1515 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1516 return emitError(
1517 "Level index exceeds the rank of the input sparse tensor");
1518 }
1519 return success();
1520}
1521
1522std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1523 return getConstantIntValue(getIndex());
1524}
1525
1526Speculation::Speculatability LvlOp::getSpeculatability() {
1527 auto constantIndex = getConstantLvlIndex();
1528 if (!constantIndex)
1530
1531 assert(constantIndex <
1532 cast<RankedTensorType>(getSource().getType()).getRank());
1534}
1535
1536OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1537 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1538 if (!lvlIndex)
1539 return {};
1540
1541 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1542 auto stt = getSparseTensorType(getSource());
1543 if (lvl >= stt.getLvlRank()) {
1544 // Follows the same convention used by tensor.dim operation. Out of bound
1545 // indices produce undefined behavior but are still valid IR. Don't choke on
1546 // them.
1547 return {};
1548 }
1549
1550 // Helper lambda to build an IndexAttr.
1551 auto getIndexAttr = [this](int64_t lvlSz) {
1552 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1553 };
1554
1555 SmallVector<Size> lvlShape = stt.getLvlShape();
1556 if (ShapedType::isStatic(lvlShape[lvl]))
1557 return getIndexAttr(lvlShape[lvl]);
1558
1559 return {};
1560}
1561
1562void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1563 SparseTensorEncodingAttr dstEnc, Value source) {
1564 auto srcStt = getSparseTensorType(source);
1565 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1566 SmallVector<int64_t> dstDimShape =
1567 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1568 auto dstTp =
1569 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1570 return build(odsBuilder, odsState, dstTp, source);
1571}
1572
1573LogicalResult ReinterpretMapOp::verify() {
1574 auto srcStt = getSparseTensorType(getSource());
1575 auto dstStt = getSparseTensorType(getDest());
1576 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1577 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1578
1579 if (srcLvlTps.size() != dstLvlTps.size())
1580 return emitError("Level rank mismatch between source/dest tensors");
1581
1582 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1583 if (srcLvlTp != dstLvlTp)
1584 return emitError("Level type mismatch between source/dest tensors");
1585
1586 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1587 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1588 return emitError("Crd/Pos width mismatch between source/dest tensors");
1589 }
1590
1591 if (srcStt.getElementType() != dstStt.getElementType())
1592 return emitError("Element type mismatch between source/dest tensors");
1593
1594 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1595 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1596 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1597 if (srcLvlSz != dstLvlSz) {
1598 // Should we allow one side to be dynamic size, e.g., <?x?> should be
1599 // compatible to <3x4>? For now, we require all the level sizes to be
1600 // *exactly* matched for simplicity.
1601 return emitError("Level size mismatch between source/dest tensors");
1602 }
1603 }
1604
1605 return success();
1606}
1607
1608OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1609 if (getSource().getType() == getDest().getType())
1610 return getSource();
1611
1612 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1613 // A -> B, B -> A ==> A
1614 if (def.getSource().getType() == getDest().getType())
1615 return def.getSource();
1616 }
1617 return {};
1618}
1619
1620template <typename ToBufferOp>
1621static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1622 PropertyRef prop, RegionRange region,
1624 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1625 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1626 Type elemTp = nullptr;
1627 bool withStride = false;
1628 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1629 elemTp = stt.getPosType();
1630 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1631 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1632 elemTp = stt.getCrdType();
1633 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1634 withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1635 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1636 elemTp = stt.getElementType();
1637 }
1638
1639 assert(elemTp && "unhandled operation.");
1640 SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1641 bufShape.push_back(ShapedType::kDynamic);
1642
1643 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1644 stt.getContext(), ShapedType::kDynamic,
1645 {ShapedType::kDynamic})
1646 : StridedLayoutAttr();
1647 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1648 return success();
1649}
1650
1651LogicalResult ToPositionsOp::verify() {
1652 auto stt = getSparseTensorType(getTensor());
1653 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1654 return emitError("requested level is out of bounds");
1655 if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1656 return emitError("unexpected type for positions");
1657 return success();
1658}
1659
1660LogicalResult
1661ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1662 ValueRange ops, DictionaryAttr attr,
1663 PropertyRef prop, RegionRange region,
1664 SmallVectorImpl<mlir::Type> &ret) {
1665 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1666}
1667
1668LogicalResult ToCoordinatesOp::verify() {
1669 auto stt = getSparseTensorType(getTensor());
1670 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1671 return emitError("requested level is out of bounds");
1672 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1673 return emitError("unexpected type for coordinates");
1674 return success();
1675}
1676
1677LogicalResult
1678ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1679 ValueRange ops, DictionaryAttr attr,
1680 PropertyRef prop, RegionRange region,
1681 SmallVectorImpl<mlir::Type> &ret) {
1682 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1683}
1684
1685LogicalResult ToCoordinatesBufferOp::verify() {
1686 auto stt = getSparseTensorType(getTensor());
1687 if (stt.getAoSCOOStart() >= stt.getLvlRank())
1688 return emitError("expected sparse tensor with a COO region");
1689 return success();
1690}
1691
1692LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1693 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1694 DictionaryAttr attr, PropertyRef prop, RegionRange region,
1695 SmallVectorImpl<mlir::Type> &ret) {
1696 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1697 ret);
1698}
1699
1700LogicalResult ToValuesOp::verify() {
1701 auto stt = getSparseTensorType(getTensor());
1702 auto mtp = getMemRefType(getResult());
1703 if (stt.getElementType() != mtp.getElementType())
1704 return emitError("unexpected mismatch in element types");
1705 return success();
1706}
1707
1708LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1709 std::optional<Location> loc,
1710 ValueRange ops, DictionaryAttr attr,
1711 PropertyRef prop, RegionRange region,
1712 SmallVectorImpl<mlir::Type> &ret) {
1713 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1714}
1715
1716LogicalResult ToSliceOffsetOp::verify() {
1717 auto rank = getSlice().getType().getRank();
1718 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1719 return emitError("requested dimension out of bound");
1720 return success();
1721}
1722
1723LogicalResult ToSliceStrideOp::verify() {
1724 auto rank = getSlice().getType().getRank();
1725 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1726 return emitError("requested dimension out of bound");
1727 return success();
1728}
1729
1730LogicalResult GetStorageSpecifierOp::verify() {
1731 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1732 getSpecifier(), getOperation());
1733}
1734
1735template <typename SpecifierOp>
1736static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1737 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1738}
1739
1740OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1741 const StorageSpecifierKind kind = getSpecifierKind();
1742 const auto lvl = getLevel();
1743 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1744 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1745 return op.getValue();
1746 return {};
1747}
1748
1749LogicalResult SetStorageSpecifierOp::verify() {
1750 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1751 getSpecifier(), getOperation());
1752}
1753
1754template <class T>
1755static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1756 const char *regionName,
1757 TypeRange inputTypes, Type outputType) {
1758 unsigned numArgs = region.getNumArguments();
1759 unsigned expectedNum = inputTypes.size();
1760 if (numArgs != expectedNum)
1761 return op->emitError() << regionName << " region must have exactly "
1762 << expectedNum << " arguments";
1763
1764 for (unsigned i = 0; i < numArgs; i++) {
1765 Type typ = region.getArgument(i).getType();
1766 if (typ != inputTypes[i])
1767 return op->emitError() << regionName << " region argument " << (i + 1)
1768 << " type mismatch";
1769 }
1770 Block &block = region.front();
1771 if (!block.mightHaveTerminator())
1772 return op->emitError() << regionName
1773 << " region must end with a terminator";
1774
1775 Operation *term = block.getTerminator();
1776 YieldOp yield = dyn_cast<YieldOp>(term);
1777 if (!yield)
1778 return op->emitError() << regionName
1779 << " region must end with sparse_tensor.yield";
1780 if (!yield.hasSingleResult() ||
1781 yield.getSingleResult().getType() != outputType)
1782 return op->emitError() << regionName << " region yield type mismatch";
1783
1784 return success();
1785}
1786
1787LogicalResult BinaryOp::verify() {
1788 NamedAttrList attrs = (*this)->getAttrs();
1789 Type leftType = getX().getType();
1790 Type rightType = getY().getType();
1791 Type outputType = getOutput().getType();
1792 Region &overlap = getOverlapRegion();
1793 Region &left = getLeftRegion();
1794 Region &right = getRightRegion();
1795
1796 // Check correct number of block arguments and return type for each
1797 // non-empty region.
1798 if (!overlap.empty()) {
1799 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1800 TypeRange{leftType, rightType}, outputType)))
1801 return failure();
1802 }
1803 if (!left.empty()) {
1804 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1805 outputType)))
1806 return failure();
1807 } else if (getLeftIdentity()) {
1808 if (leftType != outputType)
1809 return emitError("left=identity requires first argument to have the same "
1810 "type as the output");
1811 }
1812 if (!right.empty()) {
1813 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1814 outputType)))
1815 return failure();
1816 } else if (getRightIdentity()) {
1817 if (rightType != outputType)
1818 return emitError("right=identity requires second argument to have the "
1819 "same type as the output");
1820 }
1821 return success();
1822}
1823
LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  // The absent region takes no arguments (TypeRange{}) since there is no
  // stored value for it to operate on.
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    // Reject yields of block arguments of the enclosing block, and of values
    // computed locally (inside the absent region or the enclosing block),
    // with arith.constant being the only allowed local definition.
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}
1857
1858bool ConcatenateOp::needsExtraSort() {
1859 SparseTensorType dstStt = getSparseTensorType(*this);
1860 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1861 return false;
1862
1863 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1864 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1865 });
1866 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1867 // in all input/output buffers, and all input/output buffers have the same
1868 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1869 // CSC matrices along column).
1870 bool directLowerable =
1871 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1872 return !directLowerable;
1873}
1874
1875LogicalResult ConcatenateOp::verify() {
1876 const auto dstTp = getSparseTensorType(*this);
1877 const Dimension concatDim = getDimension();
1878 const Dimension dimRank = dstTp.getDimRank();
1879
1880 if (getInputs().size() <= 1)
1881 return emitError("Need at least two tensors to concatenate.");
1882
1883 if (concatDim >= dimRank)
1884 return emitError(llvm::formatv(
1885 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1886 concatDim, dimRank));
1887
1888 for (const auto &it : llvm::enumerate(getInputs())) {
1889 const auto i = it.index();
1890 const auto srcTp = getSparseTensorType(it.value());
1891 if (srcTp.hasDynamicDimShape())
1892 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1893 const Dimension srcDimRank = srcTp.getDimRank();
1894 if (srcDimRank != dimRank)
1895 return emitError(
1896 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1897 "from the output tensor (rank={2}).",
1898 i, srcDimRank, dimRank));
1899 }
1900
1901 for (Dimension d = 0; d < dimRank; d++) {
1902 const Size dstSh = dstTp.getDimShape()[d];
1903 if (d == concatDim) {
1904 if (ShapedType::isStatic(dstSh)) {
1905 // If we reach here, then all inputs have static shapes. So we
1906 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1907 // to avoid redundant assertions in the loop.
1908 Size sumSz = 0;
1909 for (const auto src : getInputs())
1910 sumSz += getSparseTensorType(src).getDimShape()[d];
1911 // If all dimension are statically known, the sum of all the input
1912 // dimensions should be equal to the output dimension.
1913 if (sumSz != dstSh)
1914 return emitError(
1915 "The concatenation dimension of the output tensor should be the "
1916 "sum of all the concatenation dimensions of the input tensors.");
1917 }
1918 } else {
1919 Size prev = dstSh;
1920 for (const auto src : getInputs()) {
1921 const auto sh = getSparseTensorType(src).getDimShape()[d];
1922 if (ShapedType::isStatic(prev) && sh != prev)
1923 return emitError("All dimensions (expect for the concatenating one) "
1924 "should be equal.");
1925 prev = sh;
1926 }
1927 }
1928 }
1929
1930 return success();
1931}
1932
// Convenience builder: push_back without the optional `n` operand, i.e.
// appending exactly one value to the buffer.
void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}
1937
1938LogicalResult PushBackOp::verify() {
1939 if (Value n = getN()) {
1940 std::optional<int64_t> nValue = getConstantIntValue(n);
1941 if (nValue && nValue.value() < 1)
1942 return emitOpError("n must be not less than 1");
1943 }
1944 return success();
1945}
1946
1947LogicalResult CompressOp::verify() {
1948 const auto stt = getSparseTensorType(getTensor());
1949 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1950 return emitOpError("incorrect number of coordinates");
1951 return success();
1952}
1953
// Builds a ForeachOp and, when `bodyBuilder` is provided, materializes its
// body block. The block arguments are laid out as:
//   [dimRank coordinates] [one element value] [reduction/init args].
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  // Invoke the callback with (coords, value, reduction args) carved out of
  // the block argument list laid out above.
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
1984
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  // Body block arguments must follow the layout established by the builder:
  // dimRank coordinates, one element value, then the init (reduction) args.
  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  // All coordinate arguments must have index type.
  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  // The value argument must match the tensor's element type.
  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}
2023
2024OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2025 if (getSparseTensorEncoding(getInputCoo().getType()) ==
2026 getSparseTensorEncoding(getResultCoo().getType()))
2027 return getInputCoo();
2028
2029 return {};
2030}
2031
2032LogicalResult ReorderCOOOp::verify() {
2033 SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2034 SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2035
2036 if (!srcStt.isCOOType() || !dstStt.isCOOType())
2037 return emitError("Expected COO sparse tensors only");
2038
2039 if (!srcStt.hasSameDimToLvl(dstStt))
2040 return emitError("Unmatched dim2lvl map between input and result COO");
2041
2042 if (srcStt.getPosType() != dstStt.getPosType() ||
2043 srcStt.getCrdType() != dstStt.getCrdType() ||
2044 srcStt.getElementType() != dstStt.getElementType())
2045 return emitError("Unmatched storage format between input and result COO");
2046
2047 return success();
2048}
2049
2050LogicalResult ReduceOp::verify() {
2051 Type inputType = getX().getType();
2052 Region &formula = getRegion();
2053 return verifyNumBlockArgs(this, formula, "reduce",
2054 TypeRange{inputType, inputType}, inputType);
2055}
2056
2057LogicalResult SelectOp::verify() {
2058 Builder b(getContext());
2059 Type inputType = getX().getType();
2060 Type boolType = b.getI1Type();
2061 Region &formula = getRegion();
2062 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2063 boolType);
2064}
2065
2066LogicalResult SortOp::verify() {
2067 AffineMap xPerm = getPermMap();
2068 uint64_t nx = xPerm.getNumDims();
2069 if (nx < 1)
2070 return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2071
2072 if (!xPerm.isPermutation())
2073 return emitError(
2074 llvm::formatv("Expected a permutation map, got {0}", xPerm));
2075
2076 // We can't check the size of the buffers when n or buffer dimensions aren't
2077 // compile-time constants.
2078 std::optional<int64_t> cn = getConstantIntValue(getN());
2079 if (!cn)
2080 return success();
2081
2082 // Verify dimensions.
2083 const auto checkDim = [&](Value v, Size minSize,
2084 const char *message) -> LogicalResult {
2085 const Size sh = getMemRefType(v).getShape()[0];
2086 if (ShapedType::isStatic(sh) && sh < minSize)
2087 return emitError(
2088 llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2089 return success();
2090 };
2091 uint64_t n = cn.value();
2092 uint64_t ny = 0;
2093 if (auto nyAttr = getNyAttr())
2094 ny = nyAttr.getInt();
2095 if (failed(checkDim(getXy(), n * (nx + ny),
2096 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2097 return failure();
2098 for (Value opnd : getYs())
2099 if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2100 return failure();
2101
2102 return success();
2103}
2104
2105//===----------------------------------------------------------------------===//
2106// Sparse Tensor Iteration Operations.
2107//===----------------------------------------------------------------------===//
2108
/// Returns the iteration-space type this iterator traverses, i.e. the space
/// with the same encoding and the same [loLvl, hiLvl) level range.
IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}
2113
/// Returns the iterator type for this iteration space, i.e. an iterator with
/// the same encoding and the same [loLvl, hiLvl) level range.
IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}
2117
2118/// Parses a level range in the form "$lo `to` $hi"
2119/// or simply "$lo" if $hi - $lo = 1
2120static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2121 Level &lvlHi) {
2122 if (parser.parseInteger(lvlLo))
2123 return failure();
2124
2125 if (succeeded(parser.parseOptionalKeyword("to"))) {
2126 if (parser.parseInteger(lvlHi))
2127 return failure();
2128 } else {
2129 lvlHi = lvlLo + 1;
2130 }
2131
2132 if (lvlHi <= lvlLo)
2133 return parser.emitError(parser.getNameLoc(),
2134 "expect larger level upper bound than lower bound");
2135
2136 return success();
2137}
2138
2139/// Parses a level range in the form "$lo `to` $hi"
2140/// or simply "$lo" if $hi - $lo = 1
2141static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2142 IntegerAttr &lvlHiAttr) {
2143 Level lvlLo, lvlHi;
2144 if (parseLevelRange(parser, lvlLo, lvlHi))
2145 return failure();
2146
2147 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2148 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2149 return success();
2150}
2151
2152/// Prints a level range in the form "$lo `to` $hi"
2153/// or simply "$lo" if $hi - $lo = 1
2154static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2155
2156 if (lo + 1 == hi)
2157 p << lo;
2158 else
2159 p << lo << " to " << hi;
2160}
2161
2162/// Prints a level range in the form "$lo `to` $hi"
2163/// or simply "$lo" if $hi - $lo = 1
2164static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2165 IntegerAttr lvlHi) {
2166 unsigned lo = lvlLo.getValue().getZExtValue();
2167 unsigned hi = lvlHi.getValue().getZExtValue();
2168 printLevelRange(p, lo, hi);
2169}
2170
2171/// Parses a list of `optional` defined list in the form of
2172/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2173/// corresponding value is not defined (e.g., to represent an undefined
2174/// coordinate in the sparse iteration space).
2175static ParseResult parseOptionalDefinedList(
2176 OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2178 unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2180 unsigned cnt = 0;
2181 ParseResult crdList =
2182 parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2183 if (parser.parseOptionalKeyword("_")) {
2184 if (parser.parseArgument(definedArgs.emplace_back()))
2185 return failure();
2186 definedSet.set(cnt);
2187 }
2188 cnt += 1;
2189 return success();
2190 });
2191
2192 if (cnt > maxCnt)
2193 return parser.emitError(parser.getNameLoc(),
2194 "parsed more value than expected.");
2195
2196 if (failed(crdList)) {
2197 return parser.emitError(
2198 parser.getNameLoc(),
2199 "expecting SSA value or \"_\" for level coordinates");
2200 }
2201 assert(definedArgs.size() == definedSet.count());
2202 return success();
2203}
2204
2205static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2206 Block::BlockArgListType blocksArgs,
2207 I64BitSet definedSet) {
2208 if (definedSet.empty())
2209 return;
2210
2211 for (unsigned i = 0; i < size; i++) {
2212 if (definedSet[i]) {
2213 p << blocksArgs.front();
2214 blocksArgs = blocksArgs.drop_front();
2215 } else {
2216 p << "_";
2217 }
2218 if (i != size - 1)
2219 p << ", ";
2220 }
2221 assert(blocksArgs.empty());
2222}
2223
2224static ParseResult
2227 // Parse "at(%crd0, _, ...)"
2228 I64BitSet crdUsedLvlSet;
2229 if (succeeded(parser.parseOptionalKeyword("at")) &&
2230 failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2231 return failure();
2232
2233 // Always use IndexType for the coordinate.
2234 for (auto &coord : coords)
2235 coord.type = parser.getBuilder().getIndexType();
2236
2237 // Set the CrdUsedLvl bitset.
2238 state.addAttribute("crdUsedLvls",
2239 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2240 return success();
2241}
2242
2243static ParseResult
2249
2250 // Parse "%iters, ... in %spaces, ..."
2251 if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2252 parser.parseOperandList(spaces))
2253 return failure();
2254
2255 if (iterators.size() != spaces.size())
2256 return parser.emitError(
2257 parser.getNameLoc(),
2258 "mismatch in number of sparse iterators and sparse spaces");
2259
2261 if (failed(parseUsedCoordList(parser, state, coords)))
2262 return failure();
2263 size_t numCrds = coords.size();
2264
2265 // Parse "iter_args(%arg = %init, ...)"
2266 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2267 if (hasIterArgs)
2268 if (parser.parseAssignmentList(blockArgs, initArgs))
2269 return failure();
2270
2271 blockArgs.append(coords);
2272
2273 SmallVector<Type> iterSpaceTps;
2274 // parse ": sparse_tensor.iter_space -> ret"
2275 if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2276 return failure();
2277 if (iterSpaceTps.size() != spaces.size())
2278 return parser.emitError(parser.getNameLoc(),
2279 "mismatch in number of iteration space operands "
2280 "and iteration space types");
2281
2282 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2283 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2284 if (!spaceTp)
2285 return parser.emitError(parser.getNameLoc(),
2286 "expected sparse_tensor.iter_space type for "
2287 "iteration space operands");
2288 it.type = spaceTp.getIteratorType();
2289 }
2290
2291 if (hasIterArgs)
2292 if (parser.parseArrowTypeList(state.types))
2293 return failure();
2294
2295 // Resolves input operands.
2296 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2297 state.operands))
2298 return failure();
2299
2300 if (hasIterArgs) {
2301 // Strip off leading args that used for coordinates.
2302 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2303 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2304 return parser.emitError(
2305 parser.getNameLoc(),
2306 "mismatch in number of iteration arguments and return values");
2307 }
2308
2309 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2310 it.type = tp;
2311 if (parser.resolveOperand(init, tp, state.operands))
2312 return failure();
2313 }
2314 }
2315 return success();
2316}
2317
2318static ParseResult
2320 SmallVectorImpl<Value> &spacesVals,
2322
2323 // Parse "(%spaces, ...)"
2326 return failure();
2327
2329 if (failed(parseUsedCoordList(parser, state, coords)))
2330 return failure();
2331 size_t numCrds = coords.size();
2332
2333 // Parse "iter_args(%arg = %init, ...)"
2335 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2336 if (hasIterArgs)
2337 if (parser.parseAssignmentList(blockArgs, initArgs))
2338 return failure();
2339 blockArgs.append(coords);
2340
2341 SmallVector<Type> iterSpaceTps;
2342 // parse ": (sparse_tensor.iter_space, ...) -> ret"
2343 if (parser.parseColon() || parser.parseLParen() ||
2344 parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2345 return failure();
2346
2347 if (iterSpaceTps.size() != spaces.size())
2348 return parser.emitError(parser.getNameLoc(),
2349 "mismatch in number of iteration space operands "
2350 "and iteration space types");
2351
2352 if (hasIterArgs)
2353 if (parser.parseArrowTypeList(state.types))
2354 return failure();
2355
2356 // Resolves input sparse iteration spaces.
2357 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2358 spacesVals))
2359 return failure();
2360 state.operands.append(spacesVals);
2361
2362 if (hasIterArgs) {
2363 // Strip off trailing args that used for coordinates.
2364 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2365 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2366 return parser.emitError(
2367 parser.getNameLoc(),
2368 "mismatch in number of iteration arguments and return values");
2369 }
2370
2371 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2372 it.type = tp;
2373 if (parser.resolveOperand(init, tp, state.operands))
2374 return failure();
2375 }
2376 }
2377 return success();
2378}
2379
2380LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2381 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2382 DictionaryAttr attr, PropertyRef prop, RegionRange region,
2383 SmallVectorImpl<mlir::Type> &ret) {
2384
2385 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2386 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2387 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2388 adaptor.getHiLvl()));
2389 return success();
2390}
2391
2392LogicalResult ExtractIterSpaceOp::verify() {
2393 if (getLoLvl() >= getHiLvl())
2394 return emitOpError("expected smaller level low than level high");
2395
2396 TypedValue<IteratorType> pIter = getParentIter();
2397 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2398 return emitOpError(
2399 "parent iterator should be specified iff level lower bound equals 0");
2400 }
2401
2402 if (pIter) {
2403 IterSpaceType spaceTp = getExtractedSpace().getType();
2404 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2405 return emitOpError(
2406 "mismatch in parent iterator encoding and iteration space encoding.");
2407
2408 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2409 return emitOpError("parent iterator should be used to extract an "
2410 "iteration space from a consecutive level.");
2411 }
2412
2413 return success();
2414}
2415
2416LogicalResult ExtractValOp::verify() {
2417 auto stt = getSparseTensorType(getTensor());
2418 auto itTp = getIterator().getType();
2419
2420 if (stt.getEncoding() != itTp.getEncoding())
2421 return emitOpError("mismatch in tensor encoding and iterator encoding.");
2422
2423 if (stt.getLvlRank() != itTp.getHiLvl())
2424 return emitOpError("must use last-level iterator to extract values. ");
2425
2426 return success();
2427}
2428
2429struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2431
2432 LogicalResult matchAndRewrite(IterateOp iterateOp,
2433 PatternRewriter &rewriter) const override {
2434 I64BitSet newUsedLvls(0);
2435 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2436 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2437 if (auto crd = iterateOp.getLvlCrd(i)) {
2438 if (crd->getUsers().empty())
2439 toRemove.set(crd->getArgNumber());
2440 else
2441 newUsedLvls.set(i);
2442 }
2443 }
2444
2445 // All coordinates are used.
2446 if (toRemove.none())
2447 return failure();
2448
2449 rewriter.startOpModification(iterateOp);
2450 iterateOp.setCrdUsedLvls(newUsedLvls);
2451 iterateOp.getBody()->eraseArguments(toRemove);
2452 rewriter.finalizeOpModification(iterateOp);
2453 return success();
2454 }
2455};
2456
// Registers the pattern that erases unused level-coordinate block arguments.
void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                            mlir::MLIRContext *context) {
  results.add<RemoveUnusedLvlCrds>(context);
}
2461
2462void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2463 Value iterSpace, ValueRange initArgs) {
2464 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2465 // All ones.
2466 I64BitSet set((1 << rank) - 1);
2467 return build(builder, odsState, iterSpace, initArgs, set);
2468}
2469
// Builds an IterateOp and its body block. The body block arguments are laid
// out as: [loop iter args] [used level coordinates] [sparse iterator].
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(iterSpace);
  odsState.addOperands(initArgs);
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  odsState.addTypes(initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(bodyRegion);

  // Starts with a list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Follows by a list of used coordinates.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Ends with sparse iterator
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);
}
2496
2497ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2498 OpAsmParser::Argument iterator;
2499 OpAsmParser::UnresolvedOperand iterSpace;
2500
2501 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2502 if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2503 return failure();
2504 if (iters.size() != 1)
2505 return parser.emitError(parser.getNameLoc(),
2506 "expected only one iterator/iteration space");
2507
2508 iterArgs.append(iters);
2509 Region *body = result.addRegion();
2510 if (parser.parseRegion(*body, iterArgs))
2511 return failure();
2512
2513 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2514
2515 // Parse the optional attribute list.
2516 if (parser.parseOptionalAttrDict(result.attributes))
2517 return failure();
2518
2519 return success();
2520}
2521
2522/// Prints the initialization list in the form of
2523/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2524/// where 'inner' values are assumed to be region arguments and 'outer' values
2525/// are regular SSA values.
2527 Block::BlockArgListType blocksArgs,
2528 ValueRange initializers,
2529 StringRef prefix = "") {
2530 assert(blocksArgs.size() == initializers.size() &&
2531 "expected same length of arguments and initializers");
2532 if (initializers.empty())
2533 return;
2534
2535 p << prefix << '(';
2536 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2537 p << std::get<0>(it) << " = " << std::get<1>(it);
2538 });
2539 p << ")";
2540}
2541
2542template <typename SparseLoopOp>
2543static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2544 if (op.getInitArgs().size() != op.getNumResults()) {
2545 return op.emitOpError(
2546 "mismatch in number of loop-carried values and defined values");
2547 }
2548 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2549 return op.emitOpError("required out-of-bound coordinates");
2550
2551 return success();
2552}
2553
// Structural verification shared with CoIterateOp.
LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
// Structural verification shared with IterateOp.
LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2556
// Printed syntax (mirrors IterateOp::parse / parseSparseIterateLoop):
//   %iter in %space at(%crd0, _, ...) iter_args(%a = %init, ...)
//     : !space_type -> ret_types { body }
void IterateOp::print(OpAsmPrinter &p) {
  p << " " << getIterator() << " in " << getIterSpace();
  if (!getCrdUsedLvls().empty()) {
    p << " at(";
    printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
    p << ")";
  }
  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");

  p << " : " << getIterSpace().getType() << " ";
  if (!getInitArgs().empty())
    p.printArrowTypeList(getInitArgs().getTypes());

  p << " ";
  // Terminators are only printed when the loop carries values; an empty
  // yield is elided otherwise.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
}
2574
LogicalResult IterateOp::verifyRegions() {
  // The explicit iterator block argument must agree with the iterated space.
  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
    return emitOpError("mismatch in iterator and iteration space type");
  if (getNumRegionIterArgs() != getNumResults())
    return emitOpError(
        "mismatch in number of basic block args and defined values");

  auto initArgs = getInitArgs();
  auto iterArgs = getRegionIterArgs();
  auto yieldVals = getYieldedValues();
  auto opResults = getResults();
  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
                        opResults.size()})) {
    return emitOpError() << "number mismatch between iter args and results.";
  }

  // Init operand, region argument, yielded value, and op result must agree
  // in type, position by position.
  for (auto [i, init, iter, yield, ret] :
       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
    if (init.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
    if (iter.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";
    if (yield.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th yield value and defined value";
  }

  return success();
}
2606
2607/// OpInterfaces' methods implemented by IterateOp.
// A sparse iterate loop has exactly one loop region: its body.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2609
// The loop's mutable init operands are exactly its iter_args operands.
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
  return getInitArgsMutable();
}
2613
2614Block::BlockArgListType IterateOp::getRegionIterArgs() {
2615 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2616}
2617
2618std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2619 return cast<sparse_tensor::YieldOp>(
2620 getRegion().getBlocks().front().getTerminator())
2621 .getResultsMutable();
2622}
2623
// All op results are loop results.
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2625
// On first entry into the body region, the forwarded operands are the loop's
// init values.
OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  return getInitArgs();
}
2629
void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself.
  regions.push_back(RegionSuccessor(&getRegion()));
  // It is possible for loop not to enter the body (zero-trip loop), so the
  // parent is always a potential successor as well.
  regions.push_back(RegionSuccessor::parent());
}
2638
2639ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) {
2640 return successor.isParent() ? ValueRange(getResults())
2641 : ValueRange(getRegionIterArgs());
2642}
2643
2644void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2645 ValueRange iterSpaces, ValueRange initArgs,
2646 unsigned numCases) {
2647 unsigned rank =
2648 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2649 // All ones.
2650 I64BitSet set((1 << rank) - 1);
2651 // Generates all-zero case bits (they only serve as placeholders), which are
2652 // supposed to be overriden later. We need to preallocate all the regions as
2653 // mlir::Region cannot be dynamically added later after the operation is
2654 // created.
2655 SmallVector<int64_t> caseBits(numCases, 0);
2656 ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2657 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2658 initArgs, set, cases,
2659 /*caseRegionsCount=*/numCases);
2660}
2661
2662ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2663
2664 SmallVector<Value> spaces;
2665 // The block argument list of each regions, it is arranged in the order of
2666 // ([used coordinate list], [loop iterations args], [sparse iterator list]).
2667 SmallVector<OpAsmParser::Argument> blockArgs;
2668 if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2669 return failure();
2670
2671 result.addAttribute("operandSegmentSizes",
2673 {static_cast<int32_t>(spaces.size()),
2674 static_cast<int32_t>(result.types.size())}));
2675
2676 SmallVector<Attribute> cases;
2677 while (succeeded(parser.parseOptionalKeyword("case"))) {
2678 // Parse one region per case.
2679 I64BitSet definedItSet;
2680 SmallVector<OpAsmParser::Argument> definedIts;
2681 if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2682 spaces.size(), OpAsmParser::Delimiter::None))
2683 return failure();
2684
2685 cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2686
2687 for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2688 // Resolve the iterator type based on the iteration space type.
2689 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2690 definedIts[i].type = spaceTp.getIteratorType();
2691 }
2692 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2693 Region *body = result.addRegion();
2694 if (parser.parseRegion(*body, definedIts))
2695 return failure();
2696
2697 CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2698 }
2699
2700 result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2701
2702 // Parse the optional attribute list.
2703 if (parser.parseOptionalAttrDict(result.attributes))
2704 return failure();
2705
2706 return success();
2707}
2708
2709void CoIterateOp::print(OpAsmPrinter &p) {
2710 p << " (";
2711 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2712 p << ")";
2713
2714 if (!getCrdUsedLvls().empty()) {
2715 p << " at(";
2716 printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2717 p << ")";
2718 }
2719
2720 printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2721
2722 p << " : (" << getIterSpaces().getTypes() << ")";
2723 if (!getInitArgs().empty())
2724 p.printArrowTypeList(getInitArgs().getTypes());
2725
2726 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2727 p.printNewline();
2728 p << "case ";
2729 printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2730 getRegionDefinedSpace(idx));
2731 p << " ";
2732 p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2733 /*printBlockTerminators=*/!getInitArgs().empty());
2734 }
2735}
2736
2737ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2738 return cast<sparse_tensor::YieldOp>(
2739 getRegion(regionIdx).getBlocks().front().getTerminator())
2740 .getResults();
2741}
2742
2743LogicalResult CoIterateOp::verifyRegions() {
2744 for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2745 if (getNumRegionIterArgs() != getNumResults())
2746 return emitOpError(
2747 "mismatch in number of basic block args and defined values");
2748
2749 auto initArgs = getInitArgs();
2750 auto iterArgs = getRegionIterArgs(r);
2751 auto yieldVals = getYieldedValues(r);
2752 auto opResults = getResults();
2753 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2754 opResults.size()})) {
2755 return emitOpError()
2756 << "number mismatch between iter args and results on " << r
2757 << "th region";
2758 }
2759
2760 for (auto [i, init, iter, yield, ret] :
2761 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2762 if (init.getType() != ret.getType())
2763 return emitOpError()
2764 << "types mismatch between " << i
2765 << "th iter operand and defined value on " << r << "th region";
2766 if (iter.getType() != ret.getType())
2767 return emitOpError() << "types mismatch between " << i
2768 << "th iter region arg and defined value on " << r
2769 << "th region";
2770 if (yield.getType() != ret.getType())
2771 return emitOpError()
2772 << "types mismatch between " << i
2773 << "th yield value and defined value on " << r << "th region";
2774 }
2775 }
2776
2777 auto cases = getRegionDefinedSpaces();
2778 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2779 if (set.size() != getNumRegions())
2780 return emitOpError("contains duplicated cases.");
2781
2782 return success();
2783}
2784
2785SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2786 SmallVector<Region *> ret;
2787 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2788 for (Region &r : getCaseRegions())
2789 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2790 ret.push_back(&r);
2791
2792 return ret;
2793}
2794
2795//===----------------------------------------------------------------------===//
2796// Sparse Tensor Dialect Setups.
2797//===----------------------------------------------------------------------===//
2798
2799/// Materialize a single constant operation from a given attribute value with
2800/// the desired resultant type.
2801Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2802 Attribute value, Type type,
2803 Location loc) {
2804 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2805 return op;
2806 return nullptr;
2807}
2808
/// Registers all attributes, types, and operations of the sparse_tensor
/// dialect, plus the externally-implemented interface promises. The lists
/// themselves are TableGen-generated and pulled in via the *.cpp.inc files;
/// the #define/#include placement must stay exactly as-is.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Promise (rather than register) the bufferization interface: the actual
  // implementations live in a separate library loaded on demand.
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2827
2828#define GET_OP_CLASSES
2829#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2830
2831#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:59
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, PropertyRef prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDiml mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isUniqueLT(LevelType lt)
Definition Enums.h:428
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Definition Enums.h:431
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition Enums.h:402
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
Definition Enums.h:432
bool isOrderedLT(LevelType lt)
Definition Enums.h:425
std::string toMLIRString(LevelType lt)
Definition Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
Definition Enums.h:421
static llvm::hash_code hash_value(LevelType lt)
uint64_t getN(LevelType lt)
Definition Enums.h:442
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
Definition Enums.h:413
uint64_t getM(LevelType lt)
Definition Enums.h:443
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
Definition Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Definition Enums.h:424
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition Enums.h:299