// Doxygen-extracted source of SparseTensorDialect.cpp (MLIR 16.0.0git).
1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
20 
21 using namespace mlir;
22 using namespace mlir::sparse_tensor;
23 
24 //===----------------------------------------------------------------------===//
25 // TensorDialect Attribute Methods.
26 //===----------------------------------------------------------------------===//
27 
28 #define GET_ATTRDEF_CLASSES
29 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
30 
/// Returns true iff `bitWidth` is one of the bitwidths supported for
/// overhead (pointer/index) storage: 0 (meaning the native `index` type),
/// 8, 16, 32, or 64.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 ||
         bitWidth == 32 || bitWidth == 64;
}
43 
44 Type SparseTensorEncodingAttr::getPointerType() const {
45  unsigned ptrWidth = getPointerBitWidth();
46  Type indexType = IndexType::get(getContext());
47  return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType;
48 }
49 
50 Type SparseTensorEncodingAttr::getIndexType() const {
51  unsigned idxWidth = getIndexBitWidth();
52  Type indexType = IndexType::get(getContext());
53  return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
54 }
55 
56 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
57  if (failed(parser.parseLess()))
58  return {};
59  // Parse the data as a dictionary.
60  DictionaryAttr dict;
61  if (failed(parser.parseAttribute(dict)))
62  return {};
63  if (failed(parser.parseGreater()))
64  return {};
65  // Process the data from the parsed dictionary value into struct-like data.
67  AffineMap dimOrd = {};
68  AffineMap higherOrd = {};
69  unsigned ptr = 0;
70  unsigned ind = 0;
71  for (const NamedAttribute &attr : dict) {
72  if (attr.getName() == "dimLevelType") {
73  auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
74  if (!arrayAttr) {
75  parser.emitError(parser.getNameLoc(),
76  "expected an array for dimension level types");
77  return {};
78  }
79  for (auto i : arrayAttr) {
80  auto strAttr = i.dyn_cast<StringAttr>();
81  if (!strAttr) {
82  parser.emitError(parser.getNameLoc(),
83  "expected a string value in dimension level types");
84  return {};
85  }
86  auto strVal = strAttr.getValue();
87  if (strVal == "dense") {
88  dlt.push_back(DimLevelType::Dense);
89  } else if (strVal == "compressed") {
90  dlt.push_back(DimLevelType::Compressed);
91  } else if (strVal == "compressed-nu") {
92  dlt.push_back(DimLevelType::CompressedNu);
93  } else if (strVal == "compressed-no") {
94  dlt.push_back(DimLevelType::CompressedNo);
95  } else if (strVal == "compressed-nu-no") {
96  dlt.push_back(DimLevelType::CompressedNuNo);
97  } else if (strVal == "singleton") {
98  dlt.push_back(DimLevelType::Singleton);
99  } else if (strVal == "singleton-nu") {
100  dlt.push_back(DimLevelType::SingletonNu);
101  } else if (strVal == "singleton-no") {
102  dlt.push_back(DimLevelType::SingletonNo);
103  } else if (strVal == "singleton-nu-no") {
104  dlt.push_back(DimLevelType::SingletonNuNo);
105  } else {
106  parser.emitError(parser.getNameLoc(),
107  "unexpected dimension level type: ")
108  << strVal;
109  return {};
110  }
111  }
112  } else if (attr.getName() == "dimOrdering") {
113  auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
114  if (!affineAttr) {
115  parser.emitError(parser.getNameLoc(),
116  "expected an affine map for dimension ordering");
117  return {};
118  }
119  dimOrd = affineAttr.getValue();
120  } else if (attr.getName() == "higherOrdering") {
121  auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
122  if (!affineAttr) {
123  parser.emitError(parser.getNameLoc(),
124  "expected an affine map for higher ordering");
125  return {};
126  }
127  higherOrd = affineAttr.getValue();
128  } else if (attr.getName() == "pointerBitWidth") {
129  auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
130  if (!intAttr) {
131  parser.emitError(parser.getNameLoc(),
132  "expected an integral pointer bitwidth");
133  return {};
134  }
135  ptr = intAttr.getInt();
136  } else if (attr.getName() == "indexBitWidth") {
137  auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
138  if (!intAttr) {
139  parser.emitError(parser.getNameLoc(),
140  "expected an integral index bitwidth");
141  return {};
142  }
143  ind = intAttr.getInt();
144  } else {
145  parser.emitError(parser.getNameLoc(), "unexpected key: ")
146  << attr.getName().strref();
147  return {};
148  }
149  }
150  // Construct struct-like storage for attribute.
151  return parser.getChecked<SparseTensorEncodingAttr>(
152  parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind);
153 }
154 
155 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
156  // Print the struct-like storage in dictionary fashion.
157  printer << "<{ dimLevelType = [ ";
158  for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
159  switch (getDimLevelType()[i]) {
160  case DimLevelType::Undef:
161  // TODO: should probably raise an error instead of printing it...
162  printer << "\"undef\"";
163  break;
164  case DimLevelType::Dense:
165  printer << "\"dense\"";
166  break;
168  printer << "\"compressed\"";
169  break;
171  printer << "\"compressed-nu\"";
172  break;
174  printer << "\"compressed-no\"";
175  break;
177  printer << "\"compressed-nu-no\"";
178  break;
180  printer << "\"singleton\"";
181  break;
183  printer << "\"singleton-nu\"";
184  break;
186  printer << "\"singleton-no\"";
187  break;
189  printer << "\"singleton-nu-no\"";
190  break;
191  }
192  if (i != e - 1)
193  printer << ", ";
194  }
195  printer << " ]";
196  // Print remaining members only for non-default values.
197  if (getDimOrdering() && !getDimOrdering().isIdentity())
198  printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
199  if (getHigherOrdering())
200  printer << ", higherOrdering = affine_map<" << getHigherOrdering() << ">";
201  if (getPointerBitWidth())
202  printer << ", pointerBitWidth = " << getPointerBitWidth();
203  if (getIndexBitWidth())
204  printer << ", indexBitWidth = " << getIndexBitWidth();
205  printer << " }>";
206 }
207 
210  ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
211  AffineMap higherOrdering, unsigned pointerBitWidth,
212  unsigned indexBitWidth) {
213  if (!acceptBitWidth(pointerBitWidth))
214  return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
215  if (!acceptBitWidth(indexBitWidth))
216  return emitError() << "unexpected index bitwidth: " << indexBitWidth;
217  if (dimOrdering) {
218  if (!dimOrdering.isPermutation())
219  return emitError()
220  << "expected a permutation affine map for dimension ordering";
221  if (dimOrdering.getNumResults() != dimLevelType.size())
222  return emitError() << "unexpected mismatch in ordering and dimension "
223  "level types size";
224  }
225  if (higherOrdering) {
226  if (higherOrdering.getNumDims() >= higherOrdering.getNumResults())
227  return emitError() << "unexpected higher ordering mapping from "
228  << higherOrdering.getNumDims() << " to "
229  << higherOrdering.getNumResults();
230  if (higherOrdering.getNumResults() != dimLevelType.size())
231  return emitError() << "unexpected mismatch in higher ordering and "
232  "dimension level types size";
233  }
234  return success();
235 }
236 
237 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
238  ArrayRef<int64_t> shape, Type elementType,
240  // Check structural integrity.
241  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
242  getHigherOrdering(), getPointerBitWidth(),
243  getIndexBitWidth())))
244  return failure();
245  // Check integrity with tensor type specifics. Dimension ordering is optional,
246  // but we always should have dimension level types for the full rank.
247  unsigned size = shape.size();
248  if (size == 0)
249  return emitError() << "expected non-scalar sparse tensor";
250  if (getHigherOrdering()) {
251  if (getHigherOrdering().getNumDims() != size)
252  return emitError() << "expected an affine map of size " << size
253  << " for higher ordering";
254 
255  // TODO: verification of higher ordering contents
256 
257  size = getHigherOrdering().getNumResults(); // higher-order size!
258  }
259  if (getDimOrdering() && getDimOrdering().getNumResults() != size)
260  return emitError() << "expected an affine map of size " << size
261  << " for dimension ordering";
262  if (getDimLevelType().size() != size)
263  return emitError() << "expected an array of size " << size
264  << " for dimension level types";
265  return success();
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // Convenience Methods.
270 //===----------------------------------------------------------------------===//
271 
272 SparseTensorEncodingAttr
274  if (auto ttp = type.dyn_cast<RankedTensorType>())
275  return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
276  return nullptr;
277 }
278 
279 bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
280  SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
281 
282  if (!enc)
283  return false;
284 
285  if (!isCompressedDim(tp, 0))
286  return false;
287 
288  for (uint64_t i = 1, e = tp.getRank(); i < e; ++i)
289  if (!isSingletonDim(tp, i))
290  return false;
291 
292  // This works for rank == 1 (unique the only compressed) and rank > 1 (unique
293  // on the last singleton).
294  return isUniqueDim(tp, tp.getRank() - 1);
295 }
296 
297 uint64_t mlir::sparse_tensor::toOrigDim(const SparseTensorEncodingAttr &enc,
298  uint64_t d) {
299  if (enc) {
300  auto order = enc.getDimOrdering();
301  if (order) {
302  assert(order.isPermutation());
303  return order.getDimPosition(d);
304  }
305  }
306  return d;
307 }
308 
309 uint64_t mlir::sparse_tensor::toStoredDim(const SparseTensorEncodingAttr &enc,
310  uint64_t d) {
311  if (enc) {
312  auto order = enc.getDimOrdering();
313  if (order) {
314  assert(order.isPermutation());
315  auto maybePos =
316  order.getResultPosition(getAffineDimExpr(d, enc.getContext()));
317  assert(maybePos.has_value());
318  return *maybePos;
319  }
320  }
321  return d;
322 }
323 
324 uint64_t mlir::sparse_tensor::toOrigDim(RankedTensorType type, uint64_t d) {
325  assert(d < static_cast<uint64_t>(type.getRank()));
326  return toOrigDim(getSparseTensorEncoding(type), d);
327 }
328 
329 uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
330  assert(d < static_cast<uint64_t>(type.getRank()));
331  return toStoredDim(getSparseTensorEncoding(type), d);
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // TensorDialect Operations.
336 //===----------------------------------------------------------------------===//
337 
338 static LogicalResult isInBounds(uint64_t dim, Value tensor) {
339  uint64_t rank = tensor.getType().cast<RankedTensorType>().getRank();
340  if (dim >= rank)
341  return failure();
342  return success(); // in bounds
343 }
344 
345 static LogicalResult isMatchingWidth(Value result, unsigned width) {
346  Type etp = result.getType().cast<MemRefType>().getElementType();
347  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
348  return success();
349  return failure();
350 }
351 
353  if (getExpandSymmetry() &&
354  getResult().getType().cast<RankedTensorType>().getRank() != 2)
355  return emitOpError("expand_symmetry can only be used for 2D tensors");
356  return success();
357 }
358 
360  if (auto tp1 = getSource().getType().dyn_cast<RankedTensorType>()) {
361  if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
362  if (tp1.getRank() != tp2.getRank())
363  return emitError("unexpected conversion mismatch in rank");
364  auto shape1 = tp1.getShape();
365  auto shape2 = tp2.getShape();
366  // Accept size matches between the source and the destination type
367  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
368  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
369  for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
370  if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
371  return emitError("unexpected conversion mismatch in dimension ") << d;
372  return success();
373  }
374  }
375  return emitError("unexpected type in convert");
376 }
377 
378 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
379  Type dstType = getType();
380  // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
381  // convert for codegen to remove. This is because we use trivial
382  // sparse-to-sparse convert to tell bufferization that the sparse codegen
383  // will expand the tensor buffer into sparse tensor storage.
384  if (!getSparseTensorEncoding(dstType) && dstType == getSource().getType())
385  return getSource();
386  return {};
387 }
388 
390  auto e = getSparseTensorEncoding(getTensor().getType());
391  if (failed(isInBounds(getDimension().getZExtValue(), getTensor())))
392  return emitError("requested pointers dimension out of bounds");
393  if (failed(isMatchingWidth(getResult(), e.getPointerBitWidth())))
394  return emitError("unexpected type for pointers");
395  return success();
396 }
397 
399  auto e = getSparseTensorEncoding(getTensor().getType());
400  if (failed(isInBounds(getDimension().getZExtValue(), getTensor())))
401  return emitError("requested indices dimension out of bounds");
402  if (failed(isMatchingWidth(getResult(), e.getIndexBitWidth())))
403  return emitError("unexpected type for indices");
404  return success();
405 }
406 
408  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
409  MemRefType mtp = getResult().getType().cast<MemRefType>();
410  if (ttp.getElementType() != mtp.getElementType())
411  return emitError("unexpected mismatch in element types");
412  return success();
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // TensorDialect Linalg.Generic Operations.
417 //===----------------------------------------------------------------------===//
418 
419 template <class T>
421  const char *regionName,
422  TypeRange inputTypes, Type outputType) {
423  unsigned numArgs = region.getNumArguments();
424  unsigned expectedNum = inputTypes.size();
425  if (numArgs != expectedNum)
426  return op->emitError() << regionName << " region must have exactly "
427  << expectedNum << " arguments";
428 
429  for (unsigned i = 0; i < numArgs; i++) {
430  Type typ = region.getArgument(i).getType();
431  if (typ != inputTypes[i])
432  return op->emitError() << regionName << " region argument " << (i + 1)
433  << " type mismatch";
434  }
435  Operation *term = region.front().getTerminator();
436  YieldOp yield = dyn_cast<YieldOp>(term);
437  if (!yield)
438  return op->emitError() << regionName
439  << " region must end with sparse_tensor.yield";
440  if (!yield.getResult() || yield.getResult().getType() != outputType)
441  return op->emitError() << regionName << " region yield type mismatch";
442 
443  return success();
444 }
445 
447  NamedAttrList attrs = (*this)->getAttrs();
448  Type leftType = getX().getType();
449  Type rightType = getY().getType();
450  Type outputType = getOutput().getType();
451  Region &overlap = getOverlapRegion();
452  Region &left = getLeftRegion();
453  Region &right = getRightRegion();
454 
455  // Check correct number of block arguments and return type for each
456  // non-empty region.
457  LogicalResult regionResult = success();
458  if (!overlap.empty()) {
459  regionResult = verifyNumBlockArgs(
460  this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
461  if (failed(regionResult))
462  return regionResult;
463  }
464  if (!left.empty()) {
465  regionResult =
466  verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
467  if (failed(regionResult))
468  return regionResult;
469  } else if (getLeftIdentity()) {
470  if (leftType != outputType)
471  return emitError("left=identity requires first argument to have the same "
472  "type as the output");
473  }
474  if (!right.empty()) {
475  regionResult = verifyNumBlockArgs(this, right, "right",
476  TypeRange{rightType}, outputType);
477  if (failed(regionResult))
478  return regionResult;
479  } else if (getRightIdentity()) {
480  if (rightType != outputType)
481  return emitError("right=identity requires second argument to have the "
482  "same type as the output");
483  }
484 
485  return success();
486 }
487 
489  Type inputType = getX().getType();
490  Type outputType = getOutput().getType();
491  LogicalResult regionResult = success();
492 
493  // Check correct number of block arguments and return type for each
494  // non-empty region.
495  Region &present = getPresentRegion();
496  if (!present.empty()) {
497  regionResult = verifyNumBlockArgs(this, present, "present",
498  TypeRange{inputType}, outputType);
499  if (failed(regionResult))
500  return regionResult;
501  }
502  Region &absent = getAbsentRegion();
503  if (!absent.empty()) {
504  regionResult =
505  verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
506  if (failed(regionResult))
507  return regionResult;
508  }
509 
510  return success();
511 }
512 
514  auto dstTp = getType().cast<RankedTensorType>();
515  uint64_t concatDim = getDimension().getZExtValue();
516  unsigned rank = dstTp.getRank();
517 
518  if (getInputs().size() <= 1)
519  return emitError("Need at least two tensors to concatenate.");
520 
521  for (auto type : getInputs().getTypes()) {
522  auto shape = type.cast<RankedTensorType>().getShape();
523  for (auto dim : shape) {
524  if (ShapedType::isDynamic(dim))
525  return emitError("Only statically-sized input tensors are supported.");
526  }
527  }
528 
529  if (concatDim >= rank)
530  return emitError(llvm::formatv(
531  "Failed to concatentate tensors with rank={0} on dimension={1}.", rank,
532  concatDim));
533 
534  for (size_t i = 0, e = getInputs().size(); i < e; i++) {
535  Value input = getInputs()[i];
536  auto inputRank = input.getType().cast<RankedTensorType>().getRank();
537  if (inputRank != rank)
538  return emitError(
539  llvm::formatv("The input tensor ${0} has a different rank (rank={1}) "
540  "from the output tensor (rank={2}).",
541  i, inputRank, rank));
542  }
543 
544  for (unsigned i = 0; i < rank; i++) {
545  auto dstDim = dstTp.getShape()[i];
546  if (i == concatDim) {
547  if (!ShapedType::isDynamic(dstDim)) {
548  unsigned sumDim = 0;
549  for (auto src : getInputs()) {
550  // If we reach here, all inputs should have static shapes.
551  auto d = src.getType().cast<RankedTensorType>().getShape()[i];
552  sumDim += d;
553  }
554  // If all dimension are statically known, the sum of all the input
555  // dimensions should be equal to the output dimension.
556  if (sumDim != dstDim)
557  return emitError(
558  "The concatenation dimension of the output tensor should be the "
559  "sum of all the concatenation dimensions of the input tensors.");
560  }
561  } else {
562  int64_t prev = dstDim;
563  for (auto src : getInputs()) {
564  auto d = src.getType().cast<RankedTensorType>().getShape()[i];
565  if (!ShapedType::isDynamic(prev) && d != prev)
566  return emitError("All dimensions (expect for the concatenating one) "
567  "should be equal.");
568  prev = d;
569  }
570  }
571  }
572 
573  return success();
574 }
575 
577  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
578  if (ttp.getRank() != static_cast<int64_t>(getIndices().size()))
579  return emitOpError("incorrect number of indices");
580  return success();
581 }
582 
// Convenience builder: delegates to the full builder with no `n` operand
// (passing a null Value for the optional repetition count).
void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Type outBuffer, Value bufferSizes, Value inBuffer,
                       Value value, APInt idx) {
  build(builder, result, outBuffer, bufferSizes, inBuffer, value,
        std::move(idx), Value());
}
589 
591  Value n = getN();
592  if (n) {
593  auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
594  if (nValue && nValue.value() < 1)
595  return emitOpError("n must be not less than 1");
596  }
597  return success();
598 }
599 
601  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
602  if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))
603  return emitOpError("incorrect number of indices");
604  return success();
605 }
606 
607 void ForeachOp::build(
608  OpBuilder &builder, OperationState &result, Value tensor,
610  bodyBuilder) {
611  build(builder, result, tensor, std::nullopt, bodyBuilder);
612 }
613 
614 void ForeachOp::build(
615  OpBuilder &builder, OperationState &result, Value tensor,
616  ValueRange initArgs,
618  bodyBuilder) {
619  build(builder, result, initArgs.getTypes(), tensor, initArgs);
620  // Builds foreach body.
621  if (!bodyBuilder)
622  return;
623  auto rtp = tensor.getType().cast<RankedTensorType>();
624  int64_t rank = rtp.getRank();
625 
626  SmallVector<Type> blockArgTypes;
627  // Starts with n index.
628  std::fill_n(std::back_inserter(blockArgTypes), rank, builder.getIndexType());
629  // Followed by one value.
630  blockArgTypes.push_back(rtp.getElementType());
631  // Followed by reduction variable.
632  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
633 
634  SmallVector<Location> blockArgLocs;
635  std::fill_n(std::back_inserter(blockArgLocs), blockArgTypes.size(),
636  tensor.getLoc());
637 
638  OpBuilder::InsertionGuard guard(builder);
639  auto &region = *result.regions.front();
640  Block *bodyBlock =
641  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
642  bodyBuilder(builder, result.location,
643  bodyBlock->getArguments().slice(0, rank),
644  bodyBlock->getArguments()[rank],
645  bodyBlock->getArguments().drop_front(rank + 1));
646 }
647 
649  auto t = getTensor().getType().cast<RankedTensorType>();
650  auto args = getBody()->getArguments();
651 
652  if (static_cast<size_t>(t.getRank()) + 1 + getInitArgs().size() !=
653  args.size())
654  return emitError("Unmatched number of arguments in the block");
655 
656  if (getNumResults() != getInitArgs().size())
657  return emitError("Mismatch in number of init arguments and results");
658 
659  if (getResultTypes() != getInitArgs().getTypes())
660  return emitError("Mismatch in types of init arguments and results");
661 
662  auto yield = cast<YieldOp>(getBody()->getTerminator());
663  if (yield.getNumOperands() != getNumResults() ||
664  yield.getOperands().getTypes() != getResultTypes())
665  return emitError("Mismatch in types of yield values and results");
666 
667  for (int64_t i = 0, e = t.getRank(); i < e; i++)
668  if (args[i].getType() != IndexType::get(getContext()))
669  emitError(
670  llvm::formatv("Expecting Index type for argument at index {0}", i));
671 
672  auto elemTp = t.getElementType();
673  auto valueTp = args[t.getRank()].getType();
674  if (elemTp != valueTp)
675  emitError(llvm::formatv("Unmatched element type between input tensor and "
676  "block argument, expected:{0}, got: {1}",
677  elemTp, valueTp));
678  return success();
679 }
680 
682  Type inputType = getX().getType();
683  LogicalResult regionResult = success();
684 
685  // Check correct number of block arguments and return type.
686  Region &formula = getRegion();
687  regionResult = verifyNumBlockArgs(this, formula, "reduce",
688  TypeRange{inputType, inputType}, inputType);
689  if (failed(regionResult))
690  return regionResult;
691 
692  return success();
693 }
694 
696  Builder b(getContext());
697 
698  Type inputType = getX().getType();
699  Type boolType = b.getI1Type();
700  LogicalResult regionResult = success();
701 
702  // Check correct number of block arguments and return type.
703  Region &formula = getRegion();
704  regionResult = verifyNumBlockArgs(this, formula, "select",
705  TypeRange{inputType}, boolType);
706  if (failed(regionResult))
707  return regionResult;
708 
709  return success();
710 }
711 
713  if (getXs().empty())
714  return emitError("need at least one xs buffer.");
715 
716  auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
717 
718  Type xtp = getXs().front().getType().cast<MemRefType>().getElementType();
719  auto checkTypes = [&](ValueRange operands,
720  bool checkEleType = true) -> LogicalResult {
721  for (Value opnd : operands) {
722  MemRefType mtp = opnd.getType().cast<MemRefType>();
723  int64_t dim = mtp.getShape()[0];
724  // We can't check the size of dynamic dimension at compile-time, but all
725  // xs and ys should have a dimension not less than n at runtime.
726  if (n && !ShapedType::isDynamic(dim) && dim < n.value())
727  return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
728  ": {0} < {1}",
729  dim, n.value()));
730 
731  if (checkEleType && xtp != mtp.getElementType())
732  return emitError("mismatch xs element types");
733  }
734  return success();
735  };
736 
737  LogicalResult result = checkTypes(getXs());
738  if (failed(result))
739  return result;
740 
741  if (n)
742  return checkTypes(getYs(), false);
743 
744  return success();
745 }
746 
748  auto cn = getN().getDefiningOp<arith::ConstantIndexOp>();
749  // We can't check the size of the buffers when n or buffer dimensions aren't
750  // compile-time constants.
751  if (!cn)
752  return success();
753 
754  uint64_t n = cn.value();
755  uint64_t nx = 1;
756  if (auto nxAttr = getNxAttr()) {
757  nx = nxAttr.getInt();
758  if (nx < 1)
759  emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
760  }
761  uint64_t ny = 0;
762  if (auto nyAttr = getNyAttr()) {
763  ny = nyAttr.getInt();
764  }
765 
766  auto checkDim = [&](Value v, uint64_t min, const char *message) {
767  MemRefType tp = v.getType().cast<MemRefType>();
768  int64_t dim = tp.getShape()[0];
769  if (!ShapedType::isDynamic(dim) && dim < (int64_t)min) {
770  emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min));
771  }
772  };
773 
774  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
775 
776  for (Value opnd : getYs()) {
777  checkDim(opnd, n, "Expected dimension(y) >= n");
778  }
779 
780  return success();
781 }
782 
784  // Check for compatible parent.
785  auto *parentOp = (*this)->getParentOp();
786  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
787  isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
788  isa<ForeachOp>(parentOp))
789  return success();
790 
791  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
792  "reduce, select or foreach");
793 }
794 
795 //===----------------------------------------------------------------------===//
796 // TensorDialect Methods.
797 //===----------------------------------------------------------------------===//
798 
// Registers all tablegen-generated attributes and operations with the
// dialect; the lists come from the generated .cpp.inc files.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
809 
810 #define GET_OP_CLASSES
811 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
812 
813 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static constexpr const bool value
Operation::operand_range getIndices(Operation *op)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:696
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static bool acceptBitWidth(unsigned bitWidth)
static LogicalResult isMatchingWidth(Value result, unsigned width)
static LogicalResult isInBounds(uint64_t dim, Value tensor)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:42
unsigned getNumDims() const
Definition: AffineMap.cpp:306
unsigned getNumResults() const
Definition: AffineMap.cpp:314
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:524
This base class exposes generic asm parser hooks, usable across the various derived parsers.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgListType getArguments()
Definition: Block.h:76
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:49
IndexType getIndexType()
Definition: Builders.cpp:56
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:395
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:33
bool isIndex() const
Definition: Types.cpp:30
U dyn_cast() const
Definition: Types.h:270
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
bool isCompressedDim(RankedTensorType type, uint64_t d)
Convenience function to test for compressed dimension (0 <= d < rank).
Definition: SparseTensor.h:69
uint64_t toStoredDim(const SparseTensorEncodingAttr &enc, uint64_t d)
bool isUniqueDim(RankedTensorType type, uint64_t d)
Convenience function to test for unique property in the given dimension (0 <= d < rank).
Definition: SparseTensor.h:90
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
DimLevelType getDimLevelType(const SparseTensorEncodingAttr &enc, uint64_t d)
Definition: SparseTensor.h:49
bool isUniqueCOOType(RankedTensorType tp)
Returns true iff the given type is a type for a COO tensor with the last dimension level type being u...
bool isSingletonDim(RankedTensorType type, uint64_t d)
Convenience function to test for singleton dimension (0 <= d < rank).
Definition: SparseTensor.h:74
uint64_t toOrigDim(const SparseTensorEncodingAttr &enc, uint64_t d)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:488
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:372
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.