//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
24// Convert type range to new types range, with sparse tensors externalized.
25static void convTypes(bool &hasAnnotation, TypeRange types,
27 SmallVectorImpl<Type> *extraTypes, bool directOut) {
28 for (auto type : types) {
29 // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
30 // also returns non-null for StorageSpecifierType (which is not a
31 // RankedTensorType), so we must check isa<RankedTensorType> as well.
32 if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
33 convTypes.push_back(type);
34 continue;
35 }
36 hasAnnotation = true;
37
38 // Convert the external representations of the pos/crd/val arrays.
39 const SparseTensorType stt(cast<RankedTensorType>(type));
41 stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
44 if (kind == SparseTensorFieldKind::PosMemRef ||
45 kind == SparseTensorFieldKind::CrdMemRef ||
46 kind == SparseTensorFieldKind::ValMemRef) {
47 auto rtp = cast<ShapedType>(t);
48 if (!directOut) {
49 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
50 if (extraTypes)
51 extraTypes->push_back(rtp);
52 }
53 convTypes.push_back(rtp);
54 }
55 return true;
56 });
57 }
58}
59
60// Convert input and output values to [dis]assemble ops for sparse tensors.
61static void convVals(OpBuilder &builder, Location loc, TypeRange types,
62 ValueRange fromVals, ValueRange extraVals,
63 SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
64 bool directOut) {
65 unsigned idx = 0;
66 for (auto type : types) {
67 // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
68 // also returns non-null for StorageSpecifierType (which is not a
69 // RankedTensorType), so we must check isa<RankedTensorType> as well.
70 if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
71 toVals.push_back(fromVals[idx++]);
72 continue;
73 }
74 // Handle sparse data.
75 auto rtp = cast<RankedTensorType>(type);
76 const SparseTensorType stt(rtp);
77 SmallVector<Value> inputs;
78 SmallVector<Type> retTypes;
79 SmallVector<Type> cntTypes;
80 if (!isIn)
81 inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
82
83 // Collect the external representations of the pos/crd/val arrays.
86 Level lv, LevelType) {
87 if (kind == SparseTensorFieldKind::PosMemRef ||
88 kind == SparseTensorFieldKind::CrdMemRef ||
89 kind == SparseTensorFieldKind::ValMemRef) {
90 if (isIn) {
91 inputs.push_back(fromVals[idx++]);
92 } else if (directOut) {
93 Value mem;
94 if (kind == SparseTensorFieldKind::PosMemRef)
95 mem = sparse_tensor::ToPositionsOp::create(builder, loc, inputs[0],
96 lv);
97 else if (kind == SparseTensorFieldKind::CrdMemRef)
98 mem = sparse_tensor::ToCoordinatesOp::create(builder, loc,
99 inputs[0], lv);
100 else
101 mem = sparse_tensor::ToValuesOp::create(builder, loc, inputs[0]);
102 toVals.push_back(mem);
103 } else {
104 ShapedType rtp = cast<ShapedType>(t);
105 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
106 inputs.push_back(extraVals[extra++]);
107 retTypes.push_back(rtp);
108 cntTypes.push_back(builder.getIndexType());
109 }
110 }
111 return true;
112 });
113
114 if (isIn) {
115 // Assemble multiple inputs into a single sparse tensor.
116 auto a = sparse_tensor::AssembleOp::create(builder, loc, rtp, inputs);
117 toVals.push_back(a.getResult());
118 } else if (!directOut) {
119 // Disassemble a single sparse input into multiple outputs.
120 // Note that this includes the counters, which are dropped.
121 unsigned len = retTypes.size();
122 retTypes.append(cntTypes);
123 auto d =
124 sparse_tensor::DisassembleOp::create(builder, loc, retTypes, inputs);
125 for (unsigned i = 0; i < len; i++)
126 toVals.push_back(d.getResult(i));
127 }
128 }
129}
130
//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//
134
135namespace {
136
137// A rewriting rules that converts public entry methods that use sparse tensors
138// as input parameters and/or output return values into wrapper methods that
139// [dis]assemble the individual tensors that constitute the actual storage used
140// externally into MLIR sparse tensors before calling the original method.
141//
142// In particular, each sparse tensor input
143//
144// void foo(..., t, ...) { }
145//
146// makes the original foo() internal and adds the following wrapper method
147//
148// void foo(..., t1..tn, ...) {
149// t = assemble t1..tn
150// _internal_foo(..., t, ...)
151// }
152//
153// and likewise, each output tensor
154//
155// ... T ... bar(...) { return ..., t, ...; }
156//
157// makes the original bar() internal and adds the following wrapper method
158//
159// ... T1..TN ... bar(..., t1'..tn') {
160// ..., t, ... = _internal_bar(...)
161// t1..tn = disassemble t, t1'..tn'
162// return ..., t1..tn, ...
163// }
164//
165// (with a direct-out variant without the disassemble).
166//
167struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
169
170 SparseFuncAssembler(MLIRContext *context, bool dO)
171 : OpRewritePattern(context), directOut(dO) {}
172
173 LogicalResult matchAndRewrite(func::FuncOp funcOp,
174 PatternRewriter &rewriter) const override {
175 // Only rewrite public entry methods.
176 if (funcOp.isPrivate())
177 return failure();
178
179 // Translate sparse tensor types to external types.
180 SmallVector<Type> inputTypes;
181 SmallVector<Type> outputTypes;
182 SmallVector<Type> extraTypes;
183 bool hasAnnotation = false;
184 convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
185 false);
186 convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
187 directOut);
188
189 // Only sparse inputs or outputs need a wrapper method.
190 if (!hasAnnotation)
191 return failure();
192
193 // Modify the original method into an internal, private method.
194 auto orgName = funcOp.getName();
195 std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
196 rewriter.modifyOpInPlace(funcOp, [&]() {
197 funcOp.setName(wrapper);
198 funcOp.setPrivate();
199 });
200
201 // Start the new public wrapper method with original name.
202 Location loc = funcOp.getLoc();
203 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
204 MLIRContext *context = modOp.getContext();
205 OpBuilder moduleBuilder(modOp.getBodyRegion());
206 unsigned extra = inputTypes.size();
207 inputTypes.append(extraTypes);
208 auto func = func::FuncOp::create(
209 moduleBuilder, loc, orgName,
210 FunctionType::get(context, inputTypes, outputTypes));
211 func.setPublic();
212
213 // Construct new wrapper method body.
214 OpBuilder::InsertionGuard insertionGuard(rewriter);
215 Block *body = func.addEntryBlock();
216 rewriter.setInsertionPointToStart(body);
217
218 // Convert inputs.
219 SmallVector<Value> inputs;
220 convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
221 ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
222
223 // Call the original, now private method. A subsequent inlining pass can
224 // determine whether cloning the method body in place is worthwhile.
225 auto org = SymbolRefAttr::get(context, wrapper);
226 auto call = func::CallOp::create(rewriter, loc, funcOp.getResultTypes(),
227 org, inputs);
228
229 // Convert outputs and return.
230 SmallVector<Value> outputs;
231 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
232 body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
233 func::ReturnOp::create(rewriter, loc, outputs);
234
235 // Finally, migrate a potential c-interface property.
236 if (funcOp->getAttrOfType<UnitAttr>(
237 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
238 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
239 UnitAttr::get(context));
240 rewriter.modifyOpInPlace(funcOp, [&]() {
241 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
242 });
243 }
244 return success();
245 }
246
247private:
248 const bool directOut;
249};
250
251} // namespace
252
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

258 bool directOut) {
259 patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
260}
return success()
static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
BlockArgListType getArguments()
Definition Block.h:97
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
Include the generated interface declarations.
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238