MLIR 23.0.0git
SparseAssembler.cpp
Go to the documentation of this file.
1//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "llvm/Support/FormatVariadic.h"
16
17using namespace mlir;
18using namespace sparse_tensor;
19
20//===----------------------------------------------------------------------===//
21// Helper methods.
22//===----------------------------------------------------------------------===//
23
24// Convert type range to new types range, with sparse tensors externalized.
25static void convTypes(bool &hasAnnotation, TypeRange types,
27 SmallVectorImpl<Type> *extraTypes, bool directOut) {
28 for (auto type : types) {
29 // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
30 // also returns non-null for StorageSpecifierType (which is not a
31 // RankedTensorType), so we must check isa<RankedTensorType> as well.
32 if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
33 convTypes.push_back(type);
34 continue;
35 }
36 hasAnnotation = true;
37
38 // Convert the external representations of the pos/crd/val arrays.
39 const SparseTensorType stt(cast<RankedTensorType>(type));
41 stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
44 if (kind == SparseTensorFieldKind::PosMemRef ||
45 kind == SparseTensorFieldKind::CrdMemRef ||
46 kind == SparseTensorFieldKind::ValMemRef) {
47 auto rtp = cast<ShapedType>(t);
48 if (!directOut) {
49 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
50 if (extraTypes)
51 extraTypes->push_back(rtp);
52 }
53 convTypes.push_back(rtp);
54 }
55 return true;
56 });
57 }
58}
59
60// Convert input and output values to [dis]assemble ops for sparse tensors.
61static void convVals(OpBuilder &builder, Location loc, TypeRange types,
62 ValueRange fromVals, ValueRange extraVals,
63 SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
64 bool directOut) {
65 unsigned idx = 0;
66 for (auto type : types) {
67 // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
68 // also returns non-null for StorageSpecifierType (which is not a
69 // RankedTensorType), so we must check isa<RankedTensorType> as well.
70 if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
71 toVals.push_back(fromVals[idx++]);
72 continue;
73 }
74 // Handle sparse data.
75 auto rtp = cast<RankedTensorType>(type);
76 const SparseTensorType stt(rtp);
77 SmallVector<Value> inputs;
78 SmallVector<Type> retTypes;
79 SmallVector<Type> cntTypes;
80 if (!isIn)
81 inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
82
83 // Collect the external representations of the pos/crd/val arrays.
86 Level lv, LevelType) {
87 if (kind == SparseTensorFieldKind::PosMemRef ||
88 kind == SparseTensorFieldKind::CrdMemRef ||
89 kind == SparseTensorFieldKind::ValMemRef) {
90 if (isIn) {
91 inputs.push_back(fromVals[idx++]);
92 } else if (directOut) {
93 Value mem;
94 if (kind == SparseTensorFieldKind::PosMemRef)
95 mem = sparse_tensor::ToPositionsOp::create(builder, loc, inputs[0],
96 lv);
97 else if (kind == SparseTensorFieldKind::CrdMemRef)
98 mem = sparse_tensor::ToCoordinatesOp::create(builder, loc,
99 inputs[0], lv);
100 else
101 mem = sparse_tensor::ToValuesOp::create(builder, loc, inputs[0]);
102 toVals.push_back(mem);
103 } else {
104 ShapedType rtp = cast<ShapedType>(t);
105 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
106 inputs.push_back(extraVals[extra++]);
107 retTypes.push_back(rtp);
108 cntTypes.push_back(builder.getIndexType());
109 }
110 }
111 return true;
112 });
113
114 if (isIn) {
115 // Assemble multiple inputs into a single sparse tensor.
116 auto a = sparse_tensor::AssembleOp::create(builder, loc, rtp, inputs);
117 toVals.push_back(a.getResult());
118 } else if (!directOut) {
119 // Disassemble a single sparse input into multiple outputs.
120 // Note that this includes the counters, which are dropped.
121 unsigned len = retTypes.size();
122 retTypes.append(cntTypes);
123 auto d =
124 sparse_tensor::DisassembleOp::create(builder, loc, retTypes, inputs);
125 for (unsigned i = 0; i < len; i++)
126 toVals.push_back(d.getResult(i));
127 }
128 }
129}
130
131//===----------------------------------------------------------------------===//
132// Rewriting rules.
133//===----------------------------------------------------------------------===//
134
135namespace {
136
137// A rewriting rules that converts public entry methods that use sparse tensors
138// as input parameters and/or output return values into wrapper methods that
139// [dis]assemble the individual tensors that constitute the actual storage used
140// externally into MLIR sparse tensors before calling the original method.
141//
142// In particular, each sparse tensor input
143//
144// void foo(..., t, ...) { }
145//
146// makes the original foo() internal and adds the following wrapper method
147//
148// void foo(..., t1..tn, ...) {
149// t = assemble t1..tn
150// _internal_foo(..., t, ...)
151// }
152//
153// and likewise, each output tensor
154//
155// ... T ... bar(...) { return ..., t, ...; }
156//
157// makes the original bar() internal and adds the following wrapper method
158//
159// ... T1..TN ... bar(..., t1'..tn') {
160// ..., t, ... = _internal_bar(...)
161// t1..tn = disassemble t, t1'..tn'
162// return ..., t1..tn, ...
163// }
164//
165// (with a direct-out variant without the disassemble).
166//
167struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
169
170 SparseFuncAssembler(MLIRContext *context, bool dO)
171 : OpRewritePattern(context), directOut(dO) {}
172
173 LogicalResult matchAndRewrite(func::FuncOp funcOp,
174 PatternRewriter &rewriter) const override {
175 // Only rewrite public entry methods.
176 if (funcOp.isPrivate())
177 return failure();
178
179 // Translate sparse tensor types to external types.
180 SmallVector<Type> inputTypes;
181 SmallVector<Type> outputTypes;
182 SmallVector<Type> extraTypes;
183 bool hasAnnotation = false;
184 convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
185 false);
186 convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
187 directOut);
188
189 // Only sparse inputs or outputs need a wrapper method.
190 if (!hasAnnotation)
191 return failure();
192
193 // Modify the original method into an internal, private method.
194 auto orgName = funcOp.getName();
195 std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
196 funcOp.setName(wrapper);
197 funcOp.setPrivate();
198
199 // Start the new public wrapper method with original name.
200 Location loc = funcOp.getLoc();
201 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
202 MLIRContext *context = modOp.getContext();
203 OpBuilder moduleBuilder(modOp.getBodyRegion());
204 unsigned extra = inputTypes.size();
205 inputTypes.append(extraTypes);
206 auto func = func::FuncOp::create(
207 moduleBuilder, loc, orgName,
208 FunctionType::get(context, inputTypes, outputTypes));
209 func.setPublic();
210
211 // Construct new wrapper method body.
212 OpBuilder::InsertionGuard insertionGuard(rewriter);
213 Block *body = func.addEntryBlock();
214 rewriter.setInsertionPointToStart(body);
215
216 // Convert inputs.
217 SmallVector<Value> inputs;
218 convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
219 ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
220
221 // Call the original, now private method. A subsequent inlining pass can
222 // determine whether cloning the method body in place is worthwhile.
223 auto org = SymbolRefAttr::get(context, wrapper);
224 auto call = func::CallOp::create(rewriter, loc, funcOp.getResultTypes(),
225 org, inputs);
226
227 // Convert outputs and return.
228 SmallVector<Value> outputs;
229 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
230 body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
231 func::ReturnOp::create(rewriter, loc, outputs);
232
233 // Finally, migrate a potential c-interface property.
234 if (funcOp->getAttrOfType<UnitAttr>(
235 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
236 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
237 UnitAttr::get(context));
238 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
239 }
240 return success();
241 }
242
243private:
244 const bool directOut;
245};
246
247} // namespace
248
249//===----------------------------------------------------------------------===//
250// Public method for populating conversion rules.
251//===----------------------------------------------------------------------===//
252
254 bool directOut) {
255 patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
256}
return success()
static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
BlockArgListType getArguments()
Definition Block.h:97
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
Include the generated interface declarations.
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238