//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"

static bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}
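
// Note: a width of 0 selects the native `index` type rather than a fixed-width
// integer; see isMatchingWidth() further below.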

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0;
  unsigned ind = 0;
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}
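
// For example (illustrative, assuming the attribute's usual
// #sparse_tensor.encoding mnemonic), a 2-d CSR-style encoding in the textual
// form handled above and printed below reads roughly like
//
//   #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ],
//     dimOrdering = affine_map<(i, j) -> (i, j)>,
//     pointerBitWidth = 64,
//     indexBitWidth = 64
//   }>
//
// where omitted keys keep the defaults set up before the loop above (empty
// ordering, bitwidths of 0).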

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  // Print the struct-like storage in dictionary fashion.
  printer << "<{ dimLevelType = [ ";
  for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
    switch (getDimLevelType()[i]) {
    case DimLevelType::Dense:
      printer << "\"dense\"";
      break;
    case DimLevelType::Compressed:
      printer << "\"compressed\"";
      break;
    case DimLevelType::Singleton:
      printer << "\"singleton\"";
      break;
    }
    if (i != e - 1)
      printer << ", ";
  }
  printer << " ]";
  if (getDimOrdering())
    printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
  printer << ", pointerBitWidth = " << getPointerBitWidth()
          << ", indexBitWidth = " << getIndexBitWidth() << " }>";
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
    unsigned pointerBitWidth, unsigned indexBitWidth) {
  if (!acceptBitWidth(pointerBitWidth))
    return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
  if (!acceptBitWidth(indexBitWidth))
    return emitError() << "unexpected index bitwidth: " << indexBitWidth;
  if (dimOrdering) {
    if (!dimOrdering.isPermutation())
      return emitError()
             << "expected a permutation affine map for dimension ordering";
    if (dimOrdering.getNumResults() != dimLevelType.size())
      return emitError() << "unexpected mismatch in ordering and dimension "
                            "level types size";
  }
  return success();
}
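
// For instance, with two dimension level types, affine_map<(i, j) -> (j, i)>
// is accepted as a valid permutation ordering (it swaps the two dimensions),
// whereas affine_map<(i, j) -> (i, i)> fails the isPermutation() check above,
// and a single-result map trips the size mismatch check.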

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<int64_t> shape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.
  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                    getPointerBitWidth(), getIndexBitWidth())))
    return failure();
  // Check integrity with tensor type specifics. Dimension ordering is
  // optional, but we always should have dimension level types for the
  // full rank.
  unsigned size = shape.size();
  if (size == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimOrdering() && getDimOrdering().getNumResults() != size)
    return emitError() << "expected an affine map of size " << size
                       << " for dimension ordering";
  if (getDimLevelType().size() != size)
    return emitError() << "expected an array of size " << size
                       << " for dimension level types";
  return success();
}

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = type.dyn_cast<RankedTensorType>())
    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
  return nullptr;
}
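
// Typical use (sketch; `tensorValue` is a hypothetical mlir::Value):
//
//   if (auto enc = getSparseTensorEncoding(tensorValue.getType()))
//     ...  // sparse: enc carries level types, ordering, and bitwidths
//   else
//     ...  // dense tensor, unranked tensor, or non-tensor type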

//===----------------------------------------------------------------------===//
// TensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult isInBounds(Value dim, Value tensor) {
  IntegerAttr constantAttr;
  if (matchPattern(dim, m_Constant(&constantAttr))) {
    unsigned d = constantAttr.getInt();
    if (d >= tensor.getType().cast<RankedTensorType>().getRank())
      return failure();
  }
  return success(); // in bounds, or symbolic
}

static LogicalResult isMatchingWidth(Value result, unsigned width) {
  Type etp = result.getType().cast<MemRefType>().getElementType();
  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
    return success();
  return failure();
}

static LogicalResult verify(NewOp op) {
  if (!getSparseTensorEncoding(op.result().getType()))
    return op.emitError("expected a sparse tensor result");
  return success();
}

static LogicalResult verify(InitOp op) {
  if (!getSparseTensorEncoding(op.result().getType()))
    return op.emitError("expected a sparse tensor result");
  RankedTensorType ttp = op.getType().cast<RankedTensorType>();
  unsigned rank = ttp.getRank();
  if (rank != op.sizes().size())
    return op.emitError("unexpected mismatch between tensor rank and sizes: ")
           << rank << " vs. " << op.sizes().size();
  auto shape = ttp.getShape();
  for (unsigned i = 0; i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      continue;
    IntegerAttr constantAttr;
    if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) ||
        constantAttr.getInt() != shape[i]) {
      return op.emitError("unexpected mismatch with static dimension size ")
             << shape[i];
    }
  }
  return success();
}

static LogicalResult verify(ConvertOp op) {
  if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
    if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
      if (tp1.getRank() != tp2.getRank())
        return op.emitError("unexpected conversion mismatch in rank");
      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
          return op.emitError("unexpected conversion mismatch in dimension ")
                 << d;
      }
      return success();
    }
  }
  return op.emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  if (getType() == source().getType())
    return source();
  return {};
}
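
// As an illustration (assuming the usual convert assembly format and a #CSR
// alias for an encoding like the one sketched earlier), a conversion such as
//
//   %1 = sparse_tensor.convert %0 : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
//
// passes the per-dimension checks above, while a convert whose result type is
// identical to its source type simply folds away to the source operand.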

static LogicalResult verify(ToPointersOp op) {
  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
    if (failed(isInBounds(op.dim(), op.tensor())))
      return op.emitError("requested pointers dimension out of bounds");
    if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
      return op.emitError("unexpected type for pointers");
    return success();
  }
  return op.emitError("expected a sparse tensor to get pointers");
}

static LogicalResult verify(ToIndicesOp op) {
  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
    if (failed(isInBounds(op.dim(), op.tensor())))
      return op.emitError("requested indices dimension out of bounds");
    if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
      return op.emitError("unexpected type for indices");
    return success();
  }
  return op.emitError("expected a sparse tensor to get indices");
}

static LogicalResult verify(ToValuesOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to get values");
  RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
  MemRefType mtp = op.result().getType().cast<MemRefType>();
  if (ttp.getElementType() != mtp.getElementType())
    return op.emitError("unexpected mismatch in element types");
  return success();
}
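
// For example, with the pointerBitWidth = 64 encoding sketched earlier, the
// pointers memref must use element type i64, and with a bitwidth of 0 it must
// use the native index type (see isMatchingWidth() above); the values memref
// simply shares the tensor's element type.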

//===----------------------------------------------------------------------===//
// TensorDialect Management Operations.
//===----------------------------------------------------------------------===//

static LogicalResult verify(LexInsertOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for insertion");
  return success();
}

static LogicalResult verify(ExpandOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for expansion");
  return success();
}

static LogicalResult verify(CompressOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for compression");
  return success();
}

static LogicalResult verify(LoadOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to materialize");
  return success();
}

static LogicalResult verify(ReleaseOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to release");
  return success();
}

static LogicalResult verify(OutOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for output");
  return success();
}

//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//

void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"