SparseTensorConversion.cpp
//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Generates call to lookup a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
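
// Example for genLvlSizeCall above (a sketch; the opaque pointer type shown
// depends on how the sparse tensor type is lowered): querying level 1 of a
// sparse tensor %t yields IR along the lines of
//   %c1 = arith.constant 1 : index
//   %sz = call @sparseLvlSize(%t, %c1) : (!llvm.ptr, index) -> index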

/// Generates call to lookup a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}
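
// Example for createOrFoldDimCall above: for a tensor of type
// tensor<8x?xf64, #CSR>, dimension 0 folds to `arith.constant 8 : index`,
// while dimension 1 emits a runtime query `call @sparseDimSize(%t, %c1)`.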

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto lt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's data.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a buffer holding the bare pointers to the underlying buffers
/// of the given level tensors and value tensor, returned as a single opaque
/// pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
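
// Example usage of NewCallParams (a sketch; this mirrors how the conversion
// patterns below materialize a new sparse tensor):
//
//   Value tensor = NewCallParams(rewriter, loc)
//                      .genBuffers(stt, dimSizesValues)
//                      .genNewCall(Action::kEmpty);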

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}
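
// E.g., for an f64 element type the name resolves to "sparseValuesF64",
// which (with EmitCInterface::On) is invoked through the C-interface wrapper
// `_mlir_ciface_sparseValuesF64` and returns a memref<?xf64>.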

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};
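
// Sketch of this rewrite: a level query such as
//   %sz = sparse_tensor.lvl %t, %c1 : tensor<?x?xf64, #CSR>
// becomes (after type conversion of %t to an opaque pointer)
//   %sz = call @sparseLvlSize(%t, %c1) : (!llvm.ptr, index) -> index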

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the reinterpret_map operator.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the operands of the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
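
// Sketch of this rewrite: `%t = tensor.empty() : tensor<8x8xf64, #CSR>`
// becomes a call to the runtime "swiss army knife" with Action::kEmpty,
// roughly (parameter names follow the kParam* ordering above):
//   %t = call @_mlir_ciface_newSparseTensor(%dimSizes, %lvlSizes, %lvlTypes,
//            %dim2lvl, %lvl2dim, %posTp, %crdTp, %valTp, %action, %ptr)
// returning the opaque pointer to the new SparseTensorStorage.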

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
                                   adaptor.getTensor(), op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for the number-of-entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the size of the values array for the number of actually
    // stored values.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Dense tensor insertion.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};
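
// The generated insertion sequence is roughly (a sketch for a rank-2 f64
// tensor; the allocas are hoisted above the outermost loop):
//   %coords = memref.alloca() : memref<2xindex>
//   %vref   = memref.alloca() : memref<f64>
//   ... store the level coordinates into %coords and the value into %vref ...
//   call @_mlir_ciface_lexInsertF64(%t, %coords, %vref)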

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};
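
// The generated call is roughly (a sketch for an f64 element type):
//   call @_mlir_ciface_expInsertF64(%tensor, %lvlCoords, %values, %filled,
//                                   %added, %count)
// followed by memref.dealloc of the three expansion buffers after the
// enclosing loop nest.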

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We simply expose the buffers to the external client. This
    // assumes the client only reads the buffers (usually copying them
    // to external data structures, such as numpy arrays).
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing
        // COO region. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //   for (i = 0; i < crdLen; i++)
      //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Convert the MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }

    // Append the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
           SparseHasRuntimeLibraryConverter>(typeConverter,
                                             patterns.getContext());
}
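
// Typical driver wiring (a sketch, not part of this file; the actual pass
// setup lives in the SparseTensor transform passes):
//   SparseTensorTypeToPtrConverter converter;
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();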