//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

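// For illustration (a sketch, not verbatim output of this pass): after type
// conversion, a sparse tensor value such as `%t : tensor<?x?xf64, #CSR>` is
// carried around as an opaque `!llvm.ptr`, and sparse primitives become
// library calls. For example,
//
//   %v = sparse_tensor.values %t : tensor<?x?xf64, #CSR> to memref<?xf64>
//
// becomes roughly
//
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr) -> memref<?xf64>
//
// where the callee is declared with the C-interface attribute so that it
// resolves to the `_mlir_ciface_`-wrapped runtime entry point.
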
#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}
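// For example (illustrative only): with this converter installed, a type such
// as `tensor<?x?xf64, #CSR>` converts to the opaque `!llvm.ptr`, while types
// without a sparse encoding are left to the other (identity) conversion.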

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
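// For illustration (a sketch of the emitted IR): both helpers above lower to
// plain runtime queries on the opaque tensor handle, e.g.
//   %s0 = call @sparseLvlSize(%tensor, %c1) : (!llvm.ptr, index) -> index
//   %s1 = call @sparseDimSize(%tensor, %c0) : (!llvm.ptr, index) -> index
// where the level/dimension argument is materialized with `constantIndex`.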

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (ShapedType::isStatic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (ShapedType::isStatic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto lt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, buf);
}

/// Generates a temporary buffer with the bare pointers to the level buffers
/// and the values buffer of the given tensors.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = memref::ExtractAlignedPointerAsIndexOp::create(
      builder, loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      arith::IndexCastOp::create(builder, loc, builder.getI64Type(), idxPtr);
  return LLVM::IntToPtrOp::create(builder, loc, getOpaquePointerType(builder),
                                  idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : LLVM::ZeroOp::create(builder, loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
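// For illustration, the typical call sequence (as used by the conversion
// rules below; a sketch only) is:
//   Value tensor = NewCallParams(rewriter, loc)
//                      .genBuffers(stt, dimSizesValues)
//                      .genNewCall(Action::kEmpty);
// which stack-allocates the descriptive buffers and then issues the single
// `newSparseTensor` call that materializes the requested storage.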

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array (AoS view).
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
                                      SparseTensorType stt, Value ptr,
                                      Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<25> name{"sparseCoordinatesBuffer",
                       overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}
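// For illustration (assuming f64 values and index-typed overhead, so the
// suffixes shown are just one possibility): the helpers above assemble
// runtime entry points such as
//   sparseValuesF64, sparsePositions0, sparseCoordinates0,
//   sparseCoordinatesBuffer0
// each returning a rank-1 memref view of the corresponding storage array.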

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite LvlOp on sparse tensors.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};
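// For illustration (a sketch; the exact textual form may differ): a level
// query with a constant level index, such as
//   %sz = sparse_tensor.lvl %t, %c1 : tensor<?x?xf64, #CSR>
// is rewritten into the runtime query
//   %sz = call @sparseLvlSize(%t, %c1) : (!llvm.ptr, index) -> index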

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the reinterpret-map operator.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};
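// For illustration (a conceptual sketch, not verbatim output): a
// `sparse_tensor.new %fileName` op roughly becomes
//   %reader = <reader-creation call emitted by genReader>
//   %tensor = call @newSparseTensor(..., %reader)  // Action::kFromReader
//   call @delSparseTensorReader(%reader) : (!llvm.ptr) -> ()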

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
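// For illustration (a conceptual sketch): both converters above rewrite
//   %t = tensor.empty() : tensor<8x8xf64, #CSR>
// (and the analogous bufferization.alloc_tensor) into a single library call
//   %t = call @newSparseTensor(%dimSizes, %lvlSizes, %lvlTypes, %dim2lvl,
//                              %lvl2dim, %posTp, %crdTp, %valTp,
//                              %action, %ptr) -> !llvm.ptr
// with the descriptive buffer arguments produced by NewCallParams::genBuffers.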

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                   op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses (AoS style).
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesBufferCall(
        rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for the number-of-entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the size of the values array for the number of stored entries.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};
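// For illustration (a sketch, assuming f64 values): the number of stored
// entries is recovered as the dynamic extent of the values array, i.e.
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr) -> memref<?xf64>
//   %n = memref.dim %v, %c0 : memref<?xf64>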

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Dense tensor insertion.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    memref::StoreOp::create(rewriter, loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};
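// For illustration (a sketch, assuming f64 elements): the insertion above
// becomes a by-reference runtime call such as
//   call @lexInsertF64(%dest, %lvlCoords, %vref)
//     : (!llvm.ptr, memref<?xindex>, memref<f64>) -> ()
// where %lvlCoords and %vref are the stack-allocated buffers generated here.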

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    linalg::FillOp::create(rewriter, loc,
                           ValueRange{constantZero(rewriter, loc, eltType)},
                           ValueRange{values});
    linalg::FillOp::create(rewriter, loc,
                           ValueRange{constantZero(rewriter, loc, boolType)},
                           ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    memref::DeallocOp::create(rewriter, loc, values);
    memref::DeallocOp::create(rewriter, loc, filled);
    memref::DeallocOp::create(rewriter, loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying them to external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing
        // COO region. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //   for (i = 0; i < crdLen; i++)
      //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = arith::MulIOp::create(rewriter, loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = scf::ForOp::create(rewriter, loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = memref::LoadOp::create(rewriter, loc, crds0, idx);
      auto c1 = memref::LoadOp::create(rewriter, loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      memref::StoreOp::create(rewriter, loc, c0, buf, args);
      args[1] = one;
      memref::StoreOp::create(rewriter, loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = bufferization::ToTensorOp::create(
          rewriter, loc,
          memref::getTensorTypeFromMemRefType(retVal[i].getType()), retVal[i]);
      retVal[i] =
          tensor::CastOp::create(rewriter, loc, op.getResultTypes()[i], tensor);
    }

    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
           SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
           SparseTensorInsertConverter, SparseTensorExpandConverter,
           SparseTensorCompressConverter, SparseTensorAssembleConverter,
           SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
          typeConverter, patterns.getContext());
}
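// For illustration, a minimal sketch of how these patterns are typically
// wired into a dialect-conversion pass (the surrounding pass boilerplate is
// hypothetical and not part of this file; only the type converter and the
// pattern population call come from the APIs above):
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeToPtrConverter converter;
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
//                          memref::MemRefDialect, scf::SCFDialect>();
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();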