MLIR  20.0.0git
SparseTensorConversion.cpp
Go to the documentation of this file.
1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor primitives into calls into a runtime
10 // support library. Sparse tensor types are converted into opaque pointers
11 // to the underlying sparse storage schemes. The use of opaque pointers
12 // together with runtime support library keeps the conversion relatively
13 // simple, but at the expense of IR opacity, which obscures opportunities
14 // for subsequent optimization of the IR. An alternative is provided by
15 // the SparseTensorCodegen pass.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "Utils/CodegenUtils.h"
20 
32 
33 using namespace mlir;
34 using namespace mlir::sparse_tensor;
35 
36 namespace {
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Maps each sparse tensor type to an opaque pointer.
43 static std::optional<Type> convertSparseTensorTypes(Type type) {
44  if (getSparseTensorEncoding(type) != nullptr)
46  return std::nullopt;
47 }
48 
49 /// Generates call to lookup a level-size. N.B., this only generates
50 /// the raw function call, and therefore (intentionally) does not perform
51 /// any dim<->lvl conversion or other logic.
52 static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
53  uint64_t lvl) {
54  StringRef name = "sparseLvlSize";
55  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
56  Type iTp = builder.getIndexType();
57  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
58  .getResult(0);
59 }
60 
61 /// Generates call to lookup a dimension-size. N.B., this only generates
62 /// the raw function call, and therefore (intentionally) does not perform
63 /// any dim<->lvl conversion or other logic.
64 static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
65  uint64_t dim) {
66  StringRef name = "sparseDimSize";
67  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
68  Type iTp = builder.getIndexType();
69  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
70  .getResult(0);
71 }
72 
73 /// Looks up a level-size by returning a statically-computed constant
74 /// (when possible), or by calling `genLvlSizeCall` (when dynamic).
75 static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
76  SparseTensorType stt, Value tensor,
77  Level lvl) {
78  // Only sparse tensors have "levels" to query.
79  assert(stt.hasEncoding());
80  // TODO: The following implementation only handles permutations;
81  // we'll need to generalize this to handle arbitrary AffineExpr.
82  //
83  // There's no need to assert `isPermutation` here: because
84  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
85  // which is all we care about (for supporting permutations).
86  const Dimension dim =
87  stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
88  const Size sz = stt.getDynamicDimSize(dim);
89  if (!ShapedType::isDynamic(sz))
90  return constantIndex(builder, loc, sz);
91  // If we cannot statically compute the size from the shape, then we
92  // must dynamically query it. (In principle we could also dynamically
93  // compute it, but since we already did so to construct the `tensor`
94  // in the first place, we might as well query rather than recompute.)
95  return genLvlSizeCall(builder, loc, tensor, lvl);
96 }
97 
98 /// Looks up a dimension-size by returning a constant from the shape
99 /// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
100 /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
101 /// of dense tensors).
102 static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
103  SparseTensorType stt, Value tensor,
104  Dimension dim) {
105  const Size sz = stt.getDynamicDimSize(dim);
106  if (!ShapedType::isDynamic(sz))
107  return constantIndex(builder, loc, sz);
108  if (stt.hasEncoding())
109  return genDimSizeCall(builder, loc, tensor, dim);
110  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
111 }
112 
113 /// Populates the array with the dimension-sizes of the given tensor.
114 static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
115  Value tensor, SmallVectorImpl<Value> &out) {
116  const Dimension dimRank = stt.getDimRank();
117  out.clear();
118  out.reserve(dimRank);
119  for (Dimension d = 0; d < dimRank; d++)
120  out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
121 }
122 
123 /// Returns an array with the dimension-sizes of the given tensor.
124 /// If the *tensor* parameters is null, the tensor type is assumed to have a
125 /// static shape.
126 static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
127  SparseTensorType stt,
128  Value tensor = Value()) {
129  SmallVector<Value> out;
130  fillDimSizes(builder, loc, stt, tensor, out);
131  return out;
132 }
133 
134 /// Generates an uninitialized buffer of the given size and type,
135 /// but returns it as type `memref<? x $tp>` (rather than as type
136 /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
137 /// this buffer must be explicitly deallocated by client.
138 static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
139  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
140  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
141 }
142 
143 /// Generates a temporary buffer for the level-types of the given encoding.
144 static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
145  SparseTensorType stt) {
146  SmallVector<Value> lvlTypes;
147  lvlTypes.reserve(stt.getLvlRank());
148  for (const auto lt : stt.getEncoding().getLvlTypes())
149  lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
150  return allocaBuffer(builder, loc, lvlTypes);
151 }
152 
153 /// Extracts the bare (aligned) pointers that point to the tensor.
154 static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
155  Value tensor) {
156  auto buf = genToMemref(builder, loc, tensor);
157  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
158 }
159 
160 /// Generates a temporary buffer for the level-types of the given encoding.
161 static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
162  ValueRange lvlTensors, Value valTensor) {
163  SmallVector<Value> lvlBarePtrs;
164  lvlBarePtrs.reserve(lvlTensors.size() + 1);
165  // Passing in lvl buffer pointers.
166  for (const auto lvl : lvlTensors)
167  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));
168 
169  // Passing in value buffer pointers.
170  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
171  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
172  loc, allocaBuffer(builder, loc, lvlBarePtrs));
173  Value idxCast =
174  builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
175  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
176  idxCast);
177 }
178 
179 /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
180 /// the "swiss army knife" method of the sparse runtime support library
181 /// for materializing sparse tensors into the computation. This abstraction
182 /// reduces the need for modifications when the API changes.
183 class NewCallParams final {
184 public:
185  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
186  NewCallParams(OpBuilder &builder, Location loc)
187  : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
188 
189  /// Initializes all static parameters (i.e., those which indicate
190  /// type-level information such as the encoding and sizes), generating
191  /// MLIR buffers as needed, and returning `this` for method chaining.
192  NewCallParams &genBuffers(SparseTensorType stt,
193  ArrayRef<Value> dimSizesValues,
194  Value dimSizesBuffer = Value()) {
195  assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
196  // Sparsity annotations.
197  params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
198  // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
199  params[kParamDimSizes] = dimSizesBuffer
200  ? dimSizesBuffer
201  : allocaBuffer(builder, loc, dimSizesValues);
202  SmallVector<Value> lvlSizesValues; // unused
203  params[kParamLvlSizes] = genMapBuffers(
204  builder, loc, stt, dimSizesValues, params[kParamDimSizes],
205  lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
206  // Secondary and primary types encoding.
207  const auto enc = stt.getEncoding();
208  params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
209  params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
210  params[kParamValTp] =
211  constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
212  // Return `this` for method chaining.
213  return *this;
214  }
215 
216  /// Checks whether all the static parameters have been initialized.
217  bool isInitialized() const {
218  for (unsigned i = 0; i < kNumStaticParams; ++i)
219  if (!params[i])
220  return false;
221  return true;
222  }
223 
224  /// Generates a function call, with the current static parameters
225  /// and the given dynamic arguments.
226  Value genNewCall(Action action, Value ptr = Value()) {
227  assert(isInitialized() && "Must initialize before genNewCall");
228  StringRef name = "newSparseTensor";
229  params[kParamAction] = constantAction(builder, loc, action);
230  params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
231  return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
232  .getResult(0);
233  }
234 
235 private:
236  static constexpr unsigned kNumStaticParams = 8;
237  static constexpr unsigned kNumDynamicParams = 2;
238  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
239  static constexpr unsigned kParamDimSizes = 0;
240  static constexpr unsigned kParamLvlSizes = 1;
241  static constexpr unsigned kParamLvlTypes = 2;
242  static constexpr unsigned kParamDim2Lvl = 3;
243  static constexpr unsigned kParamLvl2Dim = 4;
244  static constexpr unsigned kParamPosTp = 5;
245  static constexpr unsigned kParamCrdTp = 6;
246  static constexpr unsigned kParamValTp = 7;
247  static constexpr unsigned kParamAction = 8;
248  static constexpr unsigned kParamPtr = 9;
249 
250  OpBuilder &builder;
251  Location loc;
252  Type pTp;
253  Value params[kNumParams];
254 };
255 
256 /// Generates a call to obtain the values array.
257 static Value genValuesCall(OpBuilder &builder, Location loc,
258  SparseTensorType stt, Value ptr) {
259  auto eltTp = stt.getElementType();
260  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
261  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
262  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
263  .getResult(0);
264 }
265 
266 /// Generates a call to obtain the positions array.
267 static Value genPositionsCall(OpBuilder &builder, Location loc,
268  SparseTensorType stt, Value ptr, Level l) {
269  Type posTp = stt.getPosType();
270  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
271  Value lvl = constantIndex(builder, loc, l);
272  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
273  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
275  .getResult(0);
276 }
277 
278 /// Generates a call to obtain the coordinates array.
279 static Value genCoordinatesCall(OpBuilder &builder, Location loc,
280  SparseTensorType stt, Value ptr, Level l) {
281  Type crdTp = stt.getCrdType();
282  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
283  Value lvl = constantIndex(builder, loc, l);
284  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
285  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
287  .getResult(0);
288 }
289 
290 /// Generates a call to obtain the coordinates array (AoS view).
291 static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
292  SparseTensorType stt, Value ptr,
293  Level l) {
294  Type crdTp = stt.getCrdType();
295  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
296  Value lvl = constantIndex(builder, loc, l);
297  SmallString<25> name{"sparseCoordinatesBuffer",
299  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
301  .getResult(0);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // Conversion rules.
306 //===----------------------------------------------------------------------===//
307 
308 /// Sparse conversion rule for returns.
309 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
310 public:
312  LogicalResult
313  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
314  ConversionPatternRewriter &rewriter) const override {
315  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
316  return success();
317  }
318 };
319 
320 /// Sparse conversion rule for accessing level-sizes.
321 class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
322 public:
324  LogicalResult
325  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
326  ConversionPatternRewriter &rewriter) const override {
327  const auto stt = getSparseTensorType(op.getSource());
328  // Only rewrite sparse DimOp.
329  if (!stt.hasEncoding())
330  return failure();
331 
332  // Only rewrite DimOp with constant index.
333  std::optional<int64_t> lvl = op.getConstantLvlIndex();
334 
335  if (!lvl)
336  return failure();
337 
338  // By now, if the level size is constant, the operation should have already
339  // been folded by LvlOp's folder, so we generate the call unconditionally.
340  Value src = adaptor.getOperands()[0];
341  rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
342  return success();
343  }
344 };
345 
346 /// Sparse conversion rule for trivial tensor casts.
347 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
348 public:
350  LogicalResult
351  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
352  ConversionPatternRewriter &rewriter) const override {
353  // Only rewrite identically annotated source/dest.
354  auto encDst = getSparseTensorEncoding(op.getType());
355  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
356  if (!encDst || encDst != encSrc)
357  return failure();
358  rewriter.replaceOp(op, adaptor.getOperands());
359  return success();
360  }
361 };
362 
363 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
364 public:
366  LogicalResult
367  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
368  ConversionPatternRewriter &rewriter) const override {
369  // Simply fold the operation.
370  rewriter.replaceOp(op, adaptor.getSource());
371  return success();
372  }
373 };
374 
375 /// Sparse conversion rule for the new operator.
376 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
377 public:
379  LogicalResult
380  matchAndRewrite(NewOp op, OpAdaptor adaptor,
381  ConversionPatternRewriter &rewriter) const override {
382  Location loc = op.getLoc();
383  const auto stt = getSparseTensorType(op);
384  if (!stt.hasEncoding())
385  return failure();
386  // Construct the `reader` opening method calls.
387  SmallVector<Value> dimSizesValues;
388  Value dimSizesBuffer;
389  Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
390  dimSizesValues, dimSizesBuffer);
391  // Use the `reader` to parse the file.
392  Value tensor = NewCallParams(rewriter, loc)
393  .genBuffers(stt, dimSizesValues, dimSizesBuffer)
394  .genNewCall(Action::kFromReader, reader);
395  // Free the memory for `reader`.
396  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
398  rewriter.replaceOp(op, tensor);
399  return success();
400  }
401 };
402 
403 /// Sparse conversion rule for the alloc operator.
404 /// TODO(springerm): remove when bufferization.alloc_tensor is gone
405 class SparseTensorAllocConverter
406  : public OpConversionPattern<bufferization::AllocTensorOp> {
407 public:
409  LogicalResult
410  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
411  ConversionPatternRewriter &rewriter) const override {
412  const auto stt = getSparseTensorType(op);
413  if (!stt.hasEncoding())
414  return failure();
415  if (op.getCopy())
416  return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
417  // Gather all dimension sizes as SSA values.
418  Location loc = op.getLoc();
419  const Dimension dimRank = stt.getDimRank();
420  SmallVector<Value> dimSizesValues;
421  dimSizesValues.reserve(dimRank);
422  unsigned operandCtr = 0;
423  for (Dimension d = 0; d < dimRank; d++) {
424  dimSizesValues.push_back(
425  stt.isDynamicDim(d)
426  ? adaptor.getOperands()[operandCtr++]
427  : constantIndex(rewriter, loc, op.getStaticSize(d)));
428  }
429  // Generate the call to construct empty tensor. The sizes are
430  // explicitly defined by the arguments to the alloc operator.
431  rewriter.replaceOp(op, NewCallParams(rewriter, loc)
432  .genBuffers(stt, dimSizesValues)
433  .genNewCall(Action::kEmpty));
434  return success();
435  }
436 };
437 
438 /// Sparse conversion rule for the empty tensor.
439 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
440 public:
442  LogicalResult
443  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const override {
445  Location loc = op.getLoc();
446  const auto stt = getSparseTensorType(op);
447  if (!stt.hasEncoding())
448  return failure();
449  // Gather all dimension sizes as SSA values.
450  const Dimension dimRank = stt.getDimRank();
451  SmallVector<Value> dimSizesValues;
452  dimSizesValues.reserve(dimRank);
453  auto shape = op.getType().getShape();
454  unsigned operandCtr = 0;
455  for (Dimension d = 0; d < dimRank; d++) {
456  dimSizesValues.push_back(stt.isDynamicDim(d)
457  ? adaptor.getOperands()[operandCtr++]
458  : constantIndex(rewriter, loc, shape[d]));
459  }
460  // Generate the call to construct empty tensor. The sizes are
461  // explicitly defined by the arguments to the alloc operator.
462  rewriter.replaceOp(op, NewCallParams(rewriter, loc)
463  .genBuffers(stt, dimSizesValues)
464  .genNewCall(Action::kEmpty));
465  return success();
466  }
467 };
468 
469 /// Sparse conversion rule for the convert operator.
470 class SparseTensorReorderCOOConverter
471  : public OpConversionPattern<ReorderCOOOp> {
472 public:
474 
475  LogicalResult
476  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
477  ConversionPatternRewriter &rewriter) const override {
478  const Location loc = op->getLoc();
479  const auto srcTp = getSparseTensorType(op.getInputCoo());
480  const auto dstTp = getSparseTensorType(op);
481 
482  const Value src = adaptor.getInputCoo();
483 
484  NewCallParams params(rewriter, loc);
485  SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
486  rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
487  .genNewCall(Action::kSortCOOInPlace, src));
488 
489  return success();
490  }
491 };
492 
493 /// Sparse conversion rule for the dealloc operator.
494 class SparseTensorDeallocConverter
495  : public OpConversionPattern<bufferization::DeallocTensorOp> {
496 public:
498  LogicalResult
499  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
500  ConversionPatternRewriter &rewriter) const override {
501  if (!getSparseTensorType(op.getTensor()).hasEncoding())
502  return failure();
503  StringRef name = "delSparseTensor";
504  createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
506  rewriter.eraseOp(op);
507  return success();
508  }
509 };
510 
511 /// Sparse conversion rule for position accesses.
512 class SparseTensorToPositionsConverter
513  : public OpConversionPattern<ToPositionsOp> {
514 public:
516  LogicalResult
517  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
518  ConversionPatternRewriter &rewriter) const override {
519  auto stt = getSparseTensorType(op.getTensor());
520  auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
521  adaptor.getTensor(), op.getLevel());
522  rewriter.replaceOp(op, poss);
523  return success();
524  }
525 };
526 
527 /// Sparse conversion rule for coordinate accesses.
528 class SparseTensorToCoordinatesConverter
529  : public OpConversionPattern<ToCoordinatesOp> {
530 public:
532  LogicalResult
533  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
534  ConversionPatternRewriter &rewriter) const override {
535  const Location loc = op.getLoc();
536  auto stt = getSparseTensorType(op.getTensor());
537  auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
538  op.getLevel());
539  // Cast the MemRef type to the type expected by the users, though these
540  // two types should be compatible at runtime.
541  if (op.getType() != crds.getType())
542  crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
543  rewriter.replaceOp(op, crds);
544  return success();
545  }
546 };
547 
548 /// Sparse conversion rule for coordinate accesses (AoS style).
549 class SparseToCoordinatesBufferConverter
550  : public OpConversionPattern<ToCoordinatesBufferOp> {
551 public:
553  LogicalResult
554  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
555  ConversionPatternRewriter &rewriter) const override {
556  const Location loc = op.getLoc();
557  auto stt = getSparseTensorType(op.getTensor());
558  auto crds = genCoordinatesBufferCall(
559  rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
560  // Cast the MemRef type to the type expected by the users, though these
561  // two types should be compatible at runtime.
562  if (op.getType() != crds.getType())
563  crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
564  rewriter.replaceOp(op, crds);
565  return success();
566  }
567 };
568 
569 /// Sparse conversion rule for value accesses.
570 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
571 public:
573  LogicalResult
574  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
575  ConversionPatternRewriter &rewriter) const override {
576  auto stt = getSparseTensorType(op.getTensor());
577  auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
578  rewriter.replaceOp(op, vals);
579  return success();
580  }
581 };
582 
583 /// Sparse conversion rule for number of entries operator.
584 class SparseNumberOfEntriesConverter
585  : public OpConversionPattern<NumberOfEntriesOp> {
586 public:
588  LogicalResult
589  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
590  ConversionPatternRewriter &rewriter) const override {
591  // Query values array size for the actually stored values size.
592  auto stt = getSparseTensorType(op.getTensor());
593  auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
594  auto zero = constantIndex(rewriter, op.getLoc(), 0);
595  rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
596  return success();
597  }
598 };
599 
600 /// Sparse conversion rule for tensor rematerialization.
601 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
602 public:
604  LogicalResult
605  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
606  ConversionPatternRewriter &rewriter) const override {
607  if (op.getHasInserts()) {
608  // Finalize any pending insertions.
609  StringRef name = "endLexInsert";
610  createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
612  }
613  rewriter.replaceOp(op, adaptor.getOperands());
614  return success();
615  }
616 };
617 
618 /// Sparse conversion rule for the insertion operator.
619 class SparseTensorInsertConverter
620  : public OpConversionPattern<tensor::InsertOp> {
621 public:
623  LogicalResult
624  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
625  ConversionPatternRewriter &rewriter) const override {
626  // Note that the current regime only allows for strict lexicographic
627  // coordinate order. All values are passed by reference through stack
628  // allocated memrefs.
629  Location loc = op->getLoc();
630  const auto stt = getSparseTensorType(op.getDest());
631 
632  // Dense tensor insertion.
633  if (!stt.hasEncoding())
634  return failure();
635 
636  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
637  const auto elemTp = stt.getElementType();
638  const Level lvlRank = stt.getLvlRank();
639  Value lvlCoords, vref;
640  {
641  OpBuilder::InsertionGuard guard(rewriter);
642  Operation *loop = op;
643  // Finds the outermost loop.
644  while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
645  loop = l;
646 
647  if (llvm::isa<LoopLikeOpInterface>(loop)) {
648  // Hoists alloca outside the loop to avoid stack overflow.
649  rewriter.setInsertionPoint(loop);
650  }
651  lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
652  vref = genAllocaScalar(rewriter, loc, elemTp);
653  }
654  storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
655  rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
656  SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
657  createFuncCall(rewriter, loc, name, {},
658  {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
659  rewriter.replaceOp(op, adaptor.getDest());
660  return success();
661  }
662 };
663 
664 /// Sparse conversion rule for the expand operator.
665 class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
666 public:
668  LogicalResult
669  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
670  ConversionPatternRewriter &rewriter) const override {
671  Location loc = op->getLoc();
672  const auto srcTp = getSparseTensorType(op.getTensor());
673  Type eltType = srcTp.getElementType();
674  Type boolType = rewriter.getIntegerType(1);
675  Type idxType = rewriter.getIndexType();
676  // All initialization should be done on entry of the loop nest.
677  rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
678  // Get the cardinality of valid coordinates for the innermost level.
679  Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
680  srcTp.getLvlRank() - 1);
681  // Allocate temporary buffers for values, filled-switch, and coordinates.
682  // We do not use stack buffers for this, since the expanded size may
683  // be rather large (as it envelops a single expanded dense dimension).
684  Value values = genAlloc(rewriter, loc, sz, eltType);
685  Value filled = genAlloc(rewriter, loc, sz, boolType);
686  Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
687  Value zero = constantZero(rewriter, loc, idxType);
688  // Reset the values/filled-switch to all-zero/false. Note that this
689  // introduces an O(N) operation into the computation, but this reset
690  // operation is amortized over the innermost loops for the access
691  // pattern expansion. As noted in the operation doc, we would like
692  // to amortize this setup cost even between kernels.
693  rewriter.create<linalg::FillOp>(
694  loc, ValueRange{constantZero(rewriter, loc, eltType)},
695  ValueRange{values});
696  rewriter.create<linalg::FillOp>(
697  loc, ValueRange{constantZero(rewriter, loc, boolType)},
698  ValueRange{filled});
699  // Replace expansion op with these buffers and initial coordinate.
700  assert(op.getNumResults() == 4);
701  rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
702  return success();
703  }
704 };
705 
706 /// Sparse conversion rule for the compress operator.
707 class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
708 public:
710  LogicalResult
711  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
712  ConversionPatternRewriter &rewriter) const override {
713  Location loc = op->getLoc();
714  // Note that this method call resets the values/filled-switch back to
715  // all-zero/false by only iterating over the set elements, so the
716  // complexity remains proportional to the sparsity of the expanded
717  // access pattern.
718  Value values = adaptor.getValues();
719  Value filled = adaptor.getFilled();
720  Value added = adaptor.getAdded();
721  Value count = adaptor.getCount();
722  Value tensor = adaptor.getTensor();
723  const auto stt = getSparseTensorType(op.getTensor());
724  const Type elemTp = stt.getElementType();
725  const Level lvlRank = stt.getLvlRank();
726  auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
727  storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
728  SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
729  createFuncCall(rewriter, loc, name, {},
730  {tensor, lvlCoords, values, filled, added, count},
732  rewriter.replaceOp(op, adaptor.getTensor());
733  // Deallocate the buffers on exit of the loop nest.
734  Operation *parent = getTop(op);
735  rewriter.setInsertionPointAfter(parent);
736  rewriter.create<memref::DeallocOp>(loc, values);
737  rewriter.create<memref::DeallocOp>(loc, filled);
738  rewriter.create<memref::DeallocOp>(loc, added);
739  return success();
740  }
741 };
742 
743 /// Sparse conversion rule for the sparse_tensor.assemble operator.
744 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
745 public:
747  LogicalResult
748  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
749  ConversionPatternRewriter &rewriter) const override {
750  const Location loc = op->getLoc();
751  const auto dstTp = getSparseTensorType(op.getResult());
752  assert(dstTp.hasStaticDimShape());
753  SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
754  // Use a library method to transfer the external buffers from
755  // clients to the internal SparseTensorStorage. Since we cannot
756  // assume clients transfer ownership of the buffers, this method
757  // will copy all data over into a new SparseTensorStorage.
758  Value dst =
759  NewCallParams(rewriter, loc)
760  .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
761  .genNewCall(Action::kPack,
762  genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
763  adaptor.getValues()));
764  rewriter.replaceOp(op, dst);
765  return success();
766  }
767 };
768 
769 /// Sparse conversion rule for the sparse_tensor.disassemble operator.
770 /// Note that the current implementation simply exposes the buffers to
771 /// the external client. This assumes the client only reads the buffers
772 /// (usually copying it to the external data structures, such as numpy
773 /// arrays). The semantics of the disassemble operation technically
774 /// require that the copying is done here already using the out-levels
775 /// and out-values clause.
776 class SparseTensorDisassembleConverter
777  : public OpConversionPattern<DisassembleOp> {
778 public:
780  LogicalResult
781  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
782  ConversionPatternRewriter &rewriter) const override {
783  Location loc = op->getLoc();
784  auto stt = getSparseTensorType(op.getTensor());
785  SmallVector<Value> retVal;
786  SmallVector<Value> retLen;
787  // Get the positions and coordinates buffers.
788  const Level lvlRank = stt.getLvlRank();
789  Level trailCOOLen = 0;
790  for (Level l = 0; l < lvlRank; l++) {
791  if (!stt.isUniqueLvl(l) &&
792  (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
793  // A `(loose)compressed_nu` level marks the start of trailing COO
794  // start level. Since the target coordinate buffer used for trailing
795  // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
796  // scheme, we cannot simply use the internal buffers.
797  trailCOOLen = lvlRank - l;
798  break;
799  }
800  if (stt.isWithPos(l)) {
801  auto poss =
802  genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
803  auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
804  auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
805  retVal.push_back(poss);
806  retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
807  }
808  if (stt.isWithCrd(l)) {
809  auto crds =
810  genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
811  auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
812  auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
813  retVal.push_back(crds);
814  retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
815  }
816  }
817  // Handle AoS vs. SoA mismatch for COO.
818  if (trailCOOLen != 0) {
819  uint64_t cooStartLvl = lvlRank - trailCOOLen;
820  assert(!stt.isUniqueLvl(cooStartLvl) &&
821  (stt.isCompressedLvl(cooStartLvl) ||
822  stt.isLooseCompressedLvl(cooStartLvl)));
823  // Positions.
824  auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
825  cooStartLvl);
826  auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
827  auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
828  retVal.push_back(poss);
829  retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
830  // Coordinates, copied over with:
831  // for (i = 0; i < crdLen; i++)
832  // buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
833  auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
834  auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
835  cooStartLvl);
836  auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
837  cooStartLvl + 1);
838  auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
839  auto two = constantIndex(rewriter, loc, 2);
840  auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
841  Type indexType = rewriter.getIndexType();
842  auto zero = constantZero(rewriter, loc, indexType);
843  auto one = constantOne(rewriter, loc, indexType);
844  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
845  auto idx = forOp.getInductionVar();
846  rewriter.setInsertionPointToStart(forOp.getBody());
847  auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
848  auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
849  SmallVector<Value> args;
850  args.push_back(idx);
851  args.push_back(zero);
852  rewriter.create<memref::StoreOp>(loc, c0, buf, args);
853  args[1] = one;
854  rewriter.create<memref::StoreOp>(loc, c1, buf, args);
855  rewriter.setInsertionPointAfter(forOp);
856  auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
857  retVal.push_back(buf);
858  retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
859  }
860  // Get the values buffer last.
861  auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
862  auto valLenTp = op.getValLen().getType();
863  auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
864  retVal.push_back(vals);
865  retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
866 
867  // Converts MemRefs back to Tensors.
868  assert(retVal.size() + retLen.size() == op.getNumResults());
869  for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
870  auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
871  retVal[i] =
872  rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
873  }
874 
875  // Appends the actual memory length used in each buffer returned.
876  retVal.append(retLen.begin(), retLen.end());
877  rewriter.replaceOp(op, retVal);
878  return success();
879  }
880 };
881 
882 struct SparseHasRuntimeLibraryConverter
883  : public OpConversionPattern<HasRuntimeLibraryOp> {
885  LogicalResult
886  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
887  ConversionPatternRewriter &rewriter) const override {
888  auto i1Type = rewriter.getI1Type();
889  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
890  op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
891  return success();
892  }
893 };
894 
895 } // namespace
896 
897 //===----------------------------------------------------------------------===//
898 // Sparse tensor type conversion into opaque pointer.
899 //===----------------------------------------------------------------------===//
900 
902  addConversion([](Type type) { return type; });
903  addConversion(convertSparseTensorTypes);
904 }
905 
906 //===----------------------------------------------------------------------===//
907 // Public method for populating conversion rules.
908 //===----------------------------------------------------------------------===//
909 
910 /// Populates the given patterns list with conversion rules required for
911 /// the sparsification of linear algebra operations.
913  const TypeConverter &typeConverter, RewritePatternSet &patterns) {
914  patterns
915  .add<SparseReturnConverter, SparseTensorLvlOpConverter,
916  SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
917  SparseTensorAllocConverter, SparseTensorEmptyConverter,
918  SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
919  SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
920  SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
921  SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
922  SparseTensorInsertConverter, SparseTensorExpandConverter,
923  SparseTensorCompressConverter, SparseTensorAssembleConverter,
924  SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
925  typeConverter, patterns.getContext());
926 }
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
Level getLvlRank() const
Returns the level-rank.
SparseTensorEncodingAttr getEncoding() const
bool isDynamicDim(Dimension d) const
Returns true if the given dimension has dynamic size.
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:96
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:309
Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt)
Generates a constant of the internal dimension level type encoding.
Definition: CodegenUtils.h:394
Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values)
Generates a temporary buffer, initializes it with the given contents, and returns it as type memref<?...
Value constantPosTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for position overhead storage.
Definition: CodegenUtils.h:374
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Definition: CodegenUtils.h:320
Action
The actions performed by @newSparseTensor.
Definition: Enums.h:146
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:39
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value constantAction(OpBuilder &builder, Location loc, Action action)
Generates a constant of the given Action.
Definition: CodegenUtils.h:361
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
Definition: SparseTensor.h:46
Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for coordinate overhead storage.
Definition: CodegenUtils.h:381
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
Operation * getTop(Operation *op)
Scans to top of generated loop.
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, Type elemTp)
Generates a constant of the internal type-encoding for primary storage.
Definition: CodegenUtils.h:387
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx=0, Value offsetVal=Value())
Stores all the values of vs into the memref mem, which must have rank-1 and size greater-or-equal to ...
Include the generated interface declarations.
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...