MLIR 23.0.0git
SparseTensorConversion.cpp
Go to the documentation of this file.
1//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// A pass that converts sparse tensor primitives into calls into a runtime
10// support library. Sparse tensor types are converted into opaque pointers
11// to the underlying sparse storage schemes. The use of opaque pointers
12// together with runtime support library keeps the conversion relatively
13// simple, but at the expense of IR opacity, which obscures opportunities
14// for subsequent optimization of the IR. An alternative is provided by
15// the SparseTensorCodegen pass.
16//
17//===----------------------------------------------------------------------===//
18
19#include "Utils/CodegenUtils.h"
20
32
33using namespace mlir;
34using namespace mlir::sparse_tensor;
35
36namespace {
37
38//===----------------------------------------------------------------------===//
39// Helper methods.
40//===----------------------------------------------------------------------===//
41
42/// Maps each sparse tensor type to an opaque pointer.
43static std::optional<Type> convertSparseTensorTypes(Type type) {
44 if (getSparseTensorEncoding(type) != nullptr)
45 return LLVM::LLVMPointerType::get(type.getContext());
46 return std::nullopt;
47}
48
49/// Generates call to lookup a level-size. N.B., this only generates
50/// the raw function call, and therefore (intentionally) does not perform
51/// any dim<->lvl conversion or other logic.
52static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
53 uint64_t lvl) {
54 StringRef name = "sparseLvlSize";
55 SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
56 Type iTp = builder.getIndexType();
57 return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
58 .getResult(0);
59}
60
61/// Generates call to lookup a dimension-size. N.B., this only generates
62/// the raw function call, and therefore (intentionally) does not perform
63/// any dim<->lvl conversion or other logic.
64static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
65 uint64_t dim) {
66 StringRef name = "sparseDimSize";
67 SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
68 Type iTp = builder.getIndexType();
69 return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
70 .getResult(0);
71}
72
73/// Looks up a level-size by returning a statically-computed constant
74/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
75static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
77 Level lvl) {
78 // Only sparse tensors have "levels" to query.
79 assert(stt.hasEncoding());
80 // TODO: The following implementation only handles permutations;
81 // we'll need to generalize this to handle arbitrary AffineExpr.
82 //
83 // There's no need to assert `isPermutation` here: because
84 // `getDimPosition` checks that the expr isa `AffineDimExpr`,
85 // which is all we care about (for supporting permutations).
86 const Dimension dim =
87 stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
88 const Size sz = stt.getDynamicDimSize(dim);
89 if (ShapedType::isStatic(sz))
90 return constantIndex(builder, loc, sz);
91 // If we cannot statically compute the size from the shape, then we
92 // must dynamically query it. (In principle we could also dynamically
93 // compute it, but since we already did so to construct the `tensor`
94 // in the first place, we might as well query rather than recompute.)
95 return genLvlSizeCall(builder, loc, tensor, lvl);
96}
97
98/// Looks up a dimension-size by returning a constant from the shape
99/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
100/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
101/// of dense tensors).
102static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
104 Dimension dim) {
105 const Size sz = stt.getDynamicDimSize(dim);
106 if (ShapedType::isStatic(sz))
107 return constantIndex(builder, loc, sz);
108 if (stt.hasEncoding())
109 return genDimSizeCall(builder, loc, tensor, dim);
110 return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
111}
112
113/// Populates the array with the dimension-sizes of the given tensor.
114static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
116 const Dimension dimRank = stt.getDimRank();
117 out.clear();
118 out.reserve(dimRank);
119 for (Dimension d = 0; d < dimRank; d++)
120 out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
121}
122
123/// Returns an array with the dimension-sizes of the given tensor.
124/// If the *tensor* parameters is null, the tensor type is assumed to have a
125/// static shape.
126static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
128 Value tensor = Value()) {
130 fillDimSizes(builder, loc, stt, tensor, out);
131 return out;
132}
133
134/// Generates an uninitialized buffer of the given size and type,
135/// but returns it as type `memref<? x $tp>` (rather than as type
136/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
137/// this buffer must be explicitly deallocated by client.
138static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
139 auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
140 return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
141}
142
143/// Generates a temporary buffer for the level-types of the given encoding.
144static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
145 SparseTensorType stt) {
146 SmallVector<Value> lvlTypes;
147 lvlTypes.reserve(stt.getLvlRank());
148 for (const auto lt : stt.getEncoding().getLvlTypes())
149 lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
150 return allocaBuffer(builder, loc, lvlTypes);
151}
152
153/// Extracts the bare (aligned) pointers that point to the tensor.
154static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
155 Value tensor) {
156 auto buf = genToMemref(builder, loc, tensor);
157 return memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, buf);
158}
159
160/// Generates a temporary buffer for the level-types of the given encoding.
161static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
162 ValueRange lvlTensors, Value valTensor) {
163 SmallVector<Value> lvlBarePtrs;
164 lvlBarePtrs.reserve(lvlTensors.size() + 1);
165 // Passing in lvl buffer pointers.
166 for (const auto lvl : lvlTensors)
167 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));
168
169 // Passing in value buffer pointers.
170 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
171 Value idxPtr = memref::ExtractAlignedPointerAsIndexOp::create(
172 builder, loc, allocaBuffer(builder, loc, lvlBarePtrs));
173 Value idxCast =
174 arith::IndexCastOp::create(builder, loc, builder.getI64Type(), idxPtr);
175 return LLVM::IntToPtrOp::create(builder, loc, getOpaquePointerType(builder),
176 idxCast);
177}
178
179/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
180/// the "swiss army knife" method of the sparse runtime support library
181/// for materializing sparse tensors into the computation. This abstraction
182/// reduces the need for modifications when the API changes.
183class NewCallParams final {
184public:
185 /// Allocates the `ValueRange` for the `func::CallOp` parameters.
186 NewCallParams(OpBuilder &builder, Location loc)
187 : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
188
189 /// Initializes all static parameters (i.e., those which indicate
190 /// type-level information such as the encoding and sizes), generating
191 /// MLIR buffers as needed, and returning `this` for method chaining.
192 NewCallParams &genBuffers(SparseTensorType stt,
193 ArrayRef<Value> dimSizesValues,
194 Value dimSizesBuffer = Value()) {
195 assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
196 // Sparsity annotations.
197 params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
198 // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
199 params[kParamDimSizes] = dimSizesBuffer
200 ? dimSizesBuffer
201 : allocaBuffer(builder, loc, dimSizesValues);
202 SmallVector<Value> lvlSizesValues; // unused
203 params[kParamLvlSizes] = genMapBuffers(
204 builder, loc, stt, dimSizesValues, params[kParamDimSizes],
205 lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
206 // Secondary and primary types encoding.
207 const auto enc = stt.getEncoding();
208 params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
209 params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
210 params[kParamValTp] =
211 constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
212 // Return `this` for method chaining.
213 return *this;
214 }
215
216 /// Checks whether all the static parameters have been initialized.
217 bool isInitialized() const {
218 for (unsigned i = 0; i < kNumStaticParams; ++i)
219 if (!params[i])
220 return false;
221 return true;
222 }
223
224 /// Generates a function call, with the current static parameters
225 /// and the given dynamic arguments.
226 Value genNewCall(Action action, Value ptr = Value()) {
227 assert(isInitialized() && "Must initialize before genNewCall");
228 StringRef name = "newSparseTensor";
229 params[kParamAction] = constantAction(builder, loc, action);
230 params[kParamPtr] = ptr ? ptr : LLVM::ZeroOp::create(builder, loc, pTp);
231 return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
232 .getResult(0);
233 }
234
235private:
236 static constexpr unsigned kNumStaticParams = 8;
237 static constexpr unsigned kNumDynamicParams = 2;
238 static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
239 static constexpr unsigned kParamDimSizes = 0;
240 static constexpr unsigned kParamLvlSizes = 1;
241 static constexpr unsigned kParamLvlTypes = 2;
242 static constexpr unsigned kParamDim2Lvl = 3;
243 static constexpr unsigned kParamLvl2Dim = 4;
244 static constexpr unsigned kParamPosTp = 5;
245 static constexpr unsigned kParamCrdTp = 6;
246 static constexpr unsigned kParamValTp = 7;
247 static constexpr unsigned kParamAction = 8;
248 static constexpr unsigned kParamPtr = 9;
249
250 OpBuilder &builder;
251 Location loc;
252 Type pTp;
253 Value params[kNumParams];
254};
255
256/// Generates a call to obtain the values array.
257static Value genValuesCall(OpBuilder &builder, Location loc,
259 auto eltTp = stt.getElementType();
260 auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
261 SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
262 return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
263 .getResult(0);
264}
265
266/// Generates a call to obtain the positions array.
267static Value genPositionsCall(OpBuilder &builder, Location loc,
269 Type posTp = stt.getPosType();
270 auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
271 Value lvl = constantIndex(builder, loc, l);
272 SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
273 return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
275 .getResult(0);
276}
277
278/// Generates a call to obtain the coordinates array.
279static Value genCoordinatesCall(OpBuilder &builder, Location loc,
281 Type crdTp = stt.getCrdType();
282 auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
283 Value lvl = constantIndex(builder, loc, l);
284 SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
285 return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
287 .getResult(0);
288}
289
290/// Generates a call to obtain the coordinates array (AoS view).
291static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
293 Level l) {
294 Type crdTp = stt.getCrdType();
295 auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
296 Value lvl = constantIndex(builder, loc, l);
297 SmallString<25> name{"sparseCoordinatesBuffer",
299 return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
301 .getResult(0);
302}
303
304//===----------------------------------------------------------------------===//
305// Conversion rules.
306//===----------------------------------------------------------------------===//
307
308/// Sparse conversion rule for returns.
309class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
310public:
311 using OpConversionPattern::OpConversionPattern;
312 LogicalResult
313 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter) const override {
315 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
316 return success();
317 }
318};
319
320/// Sparse conversion rule for accessing level-sizes.
321class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
322public:
323 using OpConversionPattern::OpConversionPattern;
324 LogicalResult
325 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter) const override {
327 const auto stt = getSparseTensorType(op.getSource());
328 // Only rewrite sparse DimOp.
329 if (!stt.hasEncoding())
330 return failure();
331
332 // Only rewrite DimOp with constant index.
333 std::optional<int64_t> lvl = op.getConstantLvlIndex();
334
335 if (!lvl)
336 return failure();
337
338 // By now, if the level size is constant, the operation should have already
339 // been folded by LvlOp's folder, so we generate the call unconditionally.
340 Value src = adaptor.getOperands()[0];
341 rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
342 return success();
343 }
344};
345
346/// Sparse conversion rule for trivial tensor casts.
347class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
348public:
349 using OpConversionPattern::OpConversionPattern;
350 LogicalResult
351 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
352 ConversionPatternRewriter &rewriter) const override {
353 // Only rewrite identically annotated source/dest.
354 auto encDst = getSparseTensorEncoding(op.getType());
355 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
356 if (!encDst || encDst != encSrc)
357 return failure();
358 rewriter.replaceOp(op, adaptor.getOperands());
359 return success();
360 }
361};
362
363class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
364public:
365 using OpConversionPattern::OpConversionPattern;
366 LogicalResult
367 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter) const override {
369 // Simply fold the operation.
370 rewriter.replaceOp(op, adaptor.getSource());
371 return success();
372 }
373};
374
 375/// Sparse conversion rule for the new operator.
// NOTE(review): this block came through a doxygen scrape; the condition of
// the element-type check (original line 387, between the comment below and
// the `return rewriter.notifyMatchFailure(...)`) was dropped by the
// extraction — restore it from the upstream file before compiling.
 376class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 377public:
 378 using OpConversionPattern::OpConversionPattern;
 379 LogicalResult
 380 matchAndRewrite(NewOp op, OpAdaptor adaptor,
 381 ConversionPatternRewriter &rewriter) const override {
 382 Location loc = op.getLoc();
 383 const auto stt = getSparseTensorType(op);
 384 if (!stt.hasEncoding())
 385 return failure();
 386 // Verify that the element type is supported by the runtime library.
 388 return rewriter.notifyMatchFailure(op, "unsupported element type");
 389 // Construct the `reader` opening method calls.
 390 SmallVector<Value> dimSizesValues;
 391 Value dimSizesBuffer;
 392 Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
 393 dimSizesValues, dimSizesBuffer);
 394 // Use the `reader` to parse the file.
 395 Value tensor = NewCallParams(rewriter, loc)
 396 .genBuffers(stt, dimSizesValues, dimSizesBuffer)
 397 .genNewCall(Action::kFromReader, reader);
 398 // Free the memory for `reader`.
 399 createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
 400 EmitCInterface::Off);
 401 rewriter.replaceOp(op, tensor);
 402 return success();
 403 }
 404};
405
406/// Sparse conversion rule for the alloc operator.
407/// TODO(springerm): remove when bufferization.alloc_tensor is gone
408class SparseTensorAllocConverter
409 : public OpConversionPattern<bufferization::AllocTensorOp> {
410public:
411 using OpConversionPattern::OpConversionPattern;
412 LogicalResult
413 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
414 ConversionPatternRewriter &rewriter) const override {
415 const auto stt = getSparseTensorType(op);
416 if (!stt.hasEncoding())
417 return failure();
418 if (op.getCopy())
419 return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
420 // Gather all dimension sizes as SSA values.
421 Location loc = op.getLoc();
422 const Dimension dimRank = stt.getDimRank();
423 SmallVector<Value> dimSizesValues;
424 dimSizesValues.reserve(dimRank);
425 unsigned operandCtr = 0;
426 for (Dimension d = 0; d < dimRank; d++) {
427 dimSizesValues.push_back(
428 stt.isDynamicDim(d)
429 ? adaptor.getOperands()[operandCtr++]
430 : constantIndex(rewriter, loc, op.getStaticSize(d)));
431 }
432 // Generate the call to construct empty tensor. The sizes are
433 // explicitly defined by the arguments to the alloc operator.
434 rewriter.replaceOp(op, NewCallParams(rewriter, loc)
435 .genBuffers(stt, dimSizesValues)
436 .genNewCall(Action::kEmpty));
437 return success();
438 }
439};
440
441/// Sparse conversion rule for the empty tensor.
442class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
443public:
444 using OpConversionPattern::OpConversionPattern;
445 LogicalResult
446 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
447 ConversionPatternRewriter &rewriter) const override {
448 Location loc = op.getLoc();
449 const auto stt = getSparseTensorType(op);
450 if (!stt.hasEncoding())
451 return failure();
452 // Gather all dimension sizes as SSA values.
453 const Dimension dimRank = stt.getDimRank();
454 SmallVector<Value> dimSizesValues;
455 dimSizesValues.reserve(dimRank);
456 auto shape = op.getType().getShape();
457 unsigned operandCtr = 0;
458 for (Dimension d = 0; d < dimRank; d++) {
459 dimSizesValues.push_back(stt.isDynamicDim(d)
460 ? adaptor.getOperands()[operandCtr++]
461 : constantIndex(rewriter, loc, shape[d]));
462 }
463 // Generate the call to construct empty tensor. The sizes are
464 // explicitly defined by the arguments to the alloc operator.
465 rewriter.replaceOp(op, NewCallParams(rewriter, loc)
466 .genBuffers(stt, dimSizesValues)
467 .genNewCall(Action::kEmpty));
468 return success();
469 }
470};
471
472/// Sparse conversion rule for the convert operator.
473class SparseTensorReorderCOOConverter
474 : public OpConversionPattern<ReorderCOOOp> {
475public:
476 using OpConversionPattern::OpConversionPattern;
477
478 LogicalResult
479 matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
480 ConversionPatternRewriter &rewriter) const override {
481 const Location loc = op->getLoc();
482 const auto srcTp = getSparseTensorType(op.getInputCoo());
483 const auto dstTp = getSparseTensorType(op);
484
485 const Value src = adaptor.getInputCoo();
486
487 NewCallParams params(rewriter, loc);
488 SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
489 rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
490 .genNewCall(Action::kSortCOOInPlace, src));
491
492 return success();
493 }
494};
495
496/// Sparse conversion rule for the dealloc operator.
497class SparseTensorDeallocConverter
498 : public OpConversionPattern<bufferization::DeallocTensorOp> {
499public:
500 using OpConversionPattern::OpConversionPattern;
501 LogicalResult
502 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
503 ConversionPatternRewriter &rewriter) const override {
504 if (!getSparseTensorType(op.getTensor()).hasEncoding())
505 return failure();
506 StringRef name = "delSparseTensor";
507 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
508 EmitCInterface::Off);
509 rewriter.eraseOp(op);
510 return success();
511 }
512};
513
514/// Sparse conversion rule for position accesses.
515class SparseTensorToPositionsConverter
516 : public OpConversionPattern<ToPositionsOp> {
517public:
518 using OpConversionPattern::OpConversionPattern;
519 LogicalResult
520 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
521 ConversionPatternRewriter &rewriter) const override {
522 auto stt = getSparseTensorType(op.getTensor());
523 auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
524 adaptor.getTensor(), op.getLevel());
525 rewriter.replaceOp(op, poss);
526 return success();
527 }
528};
529
530/// Sparse conversion rule for coordinate accesses.
531class SparseTensorToCoordinatesConverter
532 : public OpConversionPattern<ToCoordinatesOp> {
533public:
534 using OpConversionPattern::OpConversionPattern;
535 LogicalResult
536 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
537 ConversionPatternRewriter &rewriter) const override {
538 const Location loc = op.getLoc();
539 auto stt = getSparseTensorType(op.getTensor());
540 auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
541 op.getLevel());
542 // Cast the MemRef type to the type expected by the users, though these
543 // two types should be compatible at runtime.
544 if (op.getType() != crds.getType())
545 crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
546 rewriter.replaceOp(op, crds);
547 return success();
548 }
549};
550
551/// Sparse conversion rule for coordinate accesses (AoS style).
552class SparseToCoordinatesBufferConverter
553 : public OpConversionPattern<ToCoordinatesBufferOp> {
554public:
555 using OpConversionPattern::OpConversionPattern;
556 LogicalResult
557 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
558 ConversionPatternRewriter &rewriter) const override {
559 const Location loc = op.getLoc();
560 auto stt = getSparseTensorType(op.getTensor());
561 auto crds = genCoordinatesBufferCall(
562 rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
563 // Cast the MemRef type to the type expected by the users, though these
564 // two types should be compatible at runtime.
565 if (op.getType() != crds.getType())
566 crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
567 rewriter.replaceOp(op, crds);
568 return success();
569 }
570};
571
572/// Sparse conversion rule for value accesses.
573class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
574public:
575 using OpConversionPattern::OpConversionPattern;
576 LogicalResult
577 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
578 ConversionPatternRewriter &rewriter) const override {
579 auto stt = getSparseTensorType(op.getTensor());
580 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
581 rewriter.replaceOp(op, vals);
582 return success();
583 }
584};
585
586/// Sparse conversion rule for number of entries operator.
587class SparseNumberOfEntriesConverter
588 : public OpConversionPattern<NumberOfEntriesOp> {
589public:
590 using OpConversionPattern::OpConversionPattern;
591 LogicalResult
592 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
593 ConversionPatternRewriter &rewriter) const override {
594 // Query values array size for the actually stored values size.
595 auto stt = getSparseTensorType(op.getTensor());
596 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
597 auto zero = constantIndex(rewriter, op.getLoc(), 0);
598 rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
599 return success();
600 }
601};
602
603/// Sparse conversion rule for tensor rematerialization.
604class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
605public:
606 using OpConversionPattern::OpConversionPattern;
607 LogicalResult
608 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
609 ConversionPatternRewriter &rewriter) const override {
610 if (op.getHasInserts()) {
611 // Finalize any pending insertions.
612 StringRef name = "endLexInsert";
613 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
614 EmitCInterface::Off);
615 }
616 rewriter.replaceOp(op, adaptor.getOperands());
617 return success();
618 }
619};
620
621/// Sparse conversion rule for the insertion operator.
622class SparseTensorInsertConverter
623 : public OpConversionPattern<tensor::InsertOp> {
624public:
625 using OpConversionPattern::OpConversionPattern;
626 LogicalResult
627 matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
628 ConversionPatternRewriter &rewriter) const override {
629 // Note that the current regime only allows for strict lexicographic
630 // coordinate order. All values are passed by reference through stack
631 // allocated memrefs.
632 Location loc = op->getLoc();
633 const auto stt = getSparseTensorType(op.getDest());
634
635 // Dense tensor insertion.
636 if (!stt.hasEncoding())
637 return failure();
638
639 assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
640 const auto elemTp = stt.getElementType();
641 const Level lvlRank = stt.getLvlRank();
642 Value lvlCoords, vref;
643 {
644 OpBuilder::InsertionGuard guard(rewriter);
645 Operation *loop = op;
646 // Finds the outermost loop.
647 while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
648 loop = l;
649
650 if (llvm::isa<LoopLikeOpInterface>(loop)) {
651 // Hoists alloca outside the loop to avoid stack overflow.
652 rewriter.setInsertionPoint(loop);
653 }
654 lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
655 vref = genAllocaScalar(rewriter, loc, elemTp);
656 }
657 storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
658 memref::StoreOp::create(rewriter, loc, adaptor.getScalar(), vref);
659 SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
660 createFuncCall(rewriter, loc, name, {},
661 {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
662 rewriter.replaceOp(op, adaptor.getDest());
663 return success();
664 }
665};
666
667/// Sparse conversion rule for the expand operator.
668class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
669public:
670 using OpConversionPattern::OpConversionPattern;
671 LogicalResult
672 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
673 ConversionPatternRewriter &rewriter) const override {
674 Location loc = op->getLoc();
675 const auto srcTp = getSparseTensorType(op.getTensor());
676 Type eltType = srcTp.getElementType();
677 Type boolType = rewriter.getIntegerType(1);
678 Type idxType = rewriter.getIndexType();
679 // All initialization should be done on entry of the loop nest.
680 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
681 // Get the cardinality of valid coordinates for the innermost level.
682 Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
683 srcTp.getLvlRank() - 1);
684 // Allocate temporary buffers for values, filled-switch, and coordinates.
685 // We do not use stack buffers for this, since the expanded size may
686 // be rather large (as it envelops a single expanded dense dimension).
687 Value values = genAlloc(rewriter, loc, sz, eltType);
688 Value filled = genAlloc(rewriter, loc, sz, boolType);
689 Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
690 Value zero = constantZero(rewriter, loc, idxType);
691 // Reset the values/filled-switch to all-zero/false. Note that this
692 // introduces an O(N) operation into the computation, but this reset
693 // operation is amortized over the innermost loops for the access
694 // pattern expansion. As noted in the operation doc, we would like
695 // to amortize this setup cost even between kernels.
696 linalg::FillOp::create(rewriter, loc,
697 ValueRange{constantZero(rewriter, loc, eltType)},
698 ValueRange{values});
699 linalg::FillOp::create(rewriter, loc,
700 ValueRange{constantZero(rewriter, loc, boolType)},
701 ValueRange{filled});
702 // Replace expansion op with these buffers and initial coordinate.
703 assert(op.getNumResults() == 4);
704 rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
705 return success();
706 }
707};
708
709/// Sparse conversion rule for the compress operator.
710class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
711public:
712 using OpConversionPattern::OpConversionPattern;
713 LogicalResult
714 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter) const override {
716 Location loc = op->getLoc();
717 // Note that this method call resets the values/filled-switch back to
718 // all-zero/false by only iterating over the set elements, so the
719 // complexity remains proportional to the sparsity of the expanded
720 // access pattern.
721 Value values = adaptor.getValues();
722 Value filled = adaptor.getFilled();
723 Value added = adaptor.getAdded();
724 Value count = adaptor.getCount();
725 Value tensor = adaptor.getTensor();
726 const auto stt = getSparseTensorType(op.getTensor());
727 const Type elemTp = stt.getElementType();
728 const Level lvlRank = stt.getLvlRank();
729 auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
730 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
731 SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
732 createFuncCall(rewriter, loc, name, {},
733 {tensor, lvlCoords, values, filled, added, count},
734 EmitCInterface::On);
735 Operation *parent = getTop(op);
736 rewriter.setInsertionPointAfter(parent);
737 rewriter.replaceOp(op, adaptor.getTensor());
738 // Deallocate the buffers on exit of the loop nest.
739 memref::DeallocOp::create(rewriter, loc, values);
740 memref::DeallocOp::create(rewriter, loc, filled);
741 memref::DeallocOp::create(rewriter, loc, added);
742 return success();
743 }
744};
745
746/// Sparse conversion rule for the sparse_tensor.assemble operator.
747class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
748public:
749 using OpConversionPattern::OpConversionPattern;
750 LogicalResult
751 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
752 ConversionPatternRewriter &rewriter) const override {
753 const Location loc = op->getLoc();
754 const auto dstTp = getSparseTensorType(op.getResult());
755 assert(dstTp.hasStaticDimShape());
756 SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
757 // Use a library method to transfer the external buffers from
758 // clients to the internal SparseTensorStorage. Since we cannot
759 // assume clients transfer ownership of the buffers, this method
760 // will copy all data over into a new SparseTensorStorage.
761 Value dst =
762 NewCallParams(rewriter, loc)
763 .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
764 .genNewCall(Action::kPack,
765 genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
766 adaptor.getValues()));
767 rewriter.replaceOp(op, dst);
768 return success();
769 }
770};
771
/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying it to the external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    // Buffers to be returned, and the "used length" value paired with each
    // buffer. Both vectors grow in lockstep, so retLen.size() always indexes
    // the next lvl-lens result type to produce.
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    // Number of trailing levels that form a COO region (0 when none).
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of trailing COO
        // start level. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        // Expose the internal positions buffer directly; its dim-0 size is
        // the actually-used length.
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        // Likewise for the internal coordinates buffer of this level.
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //    for (i = 0; i < crdLen; i++)
      //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      // NOTE(review): only the two coordinate streams at cooStartLvl and
      // cooStartLvl+1 are interleaved here, i.e. this path appears to assume
      // a two-level trailing COO region — confirm against the encodings the
      // runtime path supports.
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      // The AoS buffer holds two coordinates per entry.
      auto bufLen = arith::MulIOp::create(rewriter, loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      // Emit the copy loop; the builder is moved inside the loop body and
      // restored to just after the loop once the stores are emitted.
      scf::ForOp forOp = scf::ForOp::create(rewriter, loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = memref::LoadOp::create(rewriter, loc, crds0, idx);
      auto c1 = memref::LoadOp::create(rewriter, loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      memref::StoreOp::create(rewriter, loc, c0, buf, args);
      args[1] = one;
      memref::StoreOp::create(rewriter, loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = bufferization::ToTensorOp::create(
          rewriter, loc,
          memref::getTensorTypeFromMemRefType(retVal[i].getType()), retVal[i]);
      retVal[i] =
          tensor::CastOp::create(rewriter, loc, op.getResultTypes()[i], tensor);
    }

    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};
886
887struct SparseHasRuntimeLibraryConverter
888 : public OpConversionPattern<HasRuntimeLibraryOp> {
889 using OpConversionPattern::OpConversionPattern;
890 LogicalResult
891 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
892 ConversionPatternRewriter &rewriter) const override {
893 auto i1Type = rewriter.getI1Type();
894 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
895 op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
896 return success();
897 }
898};
899
900} // namespace
901
902//===----------------------------------------------------------------------===//
903// Sparse tensor type conversion into opaque pointer.
904//===----------------------------------------------------------------------===//
905
907 addConversion([](Type type) { return type; });
908 addConversion(convertSparseTensorTypes);
909}
910
911//===----------------------------------------------------------------------===//
912// Public method for populating conversion rules.
913//===----------------------------------------------------------------------===//
914
915/// Populates the given patterns list with conversion rules required for
916/// the sparsification of linear algebra operations.
918 const TypeConverter &typeConverter, RewritePatternSet &patterns) {
919 patterns
920 .add<SparseReturnConverter, SparseTensorLvlOpConverter,
921 SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
922 SparseTensorAllocConverter, SparseTensorEmptyConverter,
923 SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
924 SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
925 SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
926 SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
927 SparseTensorInsertConverter, SparseTensorExpandConverter,
928 SparseTensorCompressConverter, SparseTensorAssembleConverter,
929 SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
930 typeConverter, patterns.getContext());
931}
return success()
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
IntegerType getI64Type()
Definition Builders.cpp:69
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
Level getLvlRank() const
Returns the level-rank.
SparseTensorEncodingAttr getEncoding() const
bool isDynamicDim(Dimension d) const
Returns true if the given dimension has dynamic size.
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:62
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt)
Generates a constant of the internal dimension level type encoding.
Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values)
Generates a temporary buffer, initializes it with the given contents, and returns it as type memref<?
Value constantPosTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for position overhead storage.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Action
The actions performed by @newSparseTensor.
Definition Enums.h:146
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value constantAction(OpBuilder &builder, Location loc, Action action)
Generates a constant of the given Action.
Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for coordinate overhead storage.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
bool isValidPrimaryType(Type elemTp)
Returns true if the given type is a valid sparse tensor element type supported by the runtime library...
uint64_t Level
The type of level identifiers and level-ranks.
Operation * getTop(Operation *op)
Scans to top of generated loop.
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, Type elemTp)
Generates a constant of the internal type-encoding for primary storage.
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx=0, Value offsetVal=Value())
Stores all the values of vs into the memref mem, which must have rank-1 and size greater-or-equal to ...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.