MLIR 22.0.0git
SparseAssembler.cpp
Go to the documentation of this file.
1//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "llvm/Support/FormatVariadic.h"
16
17using namespace mlir;
18using namespace sparse_tensor;
19
20//===----------------------------------------------------------------------===//
21// Helper methods.
22//===----------------------------------------------------------------------===//
23
24// Convert type range to new types range, with sparse tensors externalized.
25static void convTypes(bool &hasAnnotation, TypeRange types,
27 SmallVectorImpl<Type> *extraTypes, bool directOut) {
28 for (auto type : types) {
29 // All "dense" data passes through unmodified.
30 if (!getSparseTensorEncoding(type)) {
31 convTypes.push_back(type);
32 continue;
33 }
34 hasAnnotation = true;
35
36 // Convert the external representations of the pos/crd/val arrays.
37 const SparseTensorType stt(cast<RankedTensorType>(type));
39 stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
42 if (kind == SparseTensorFieldKind::PosMemRef ||
43 kind == SparseTensorFieldKind::CrdMemRef ||
44 kind == SparseTensorFieldKind::ValMemRef) {
45 auto rtp = cast<ShapedType>(t);
46 if (!directOut) {
47 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
48 if (extraTypes)
49 extraTypes->push_back(rtp);
50 }
51 convTypes.push_back(rtp);
52 }
53 return true;
54 });
55 }
56}
57
58// Convert input and output values to [dis]assemble ops for sparse tensors.
59static void convVals(OpBuilder &builder, Location loc, TypeRange types,
60 ValueRange fromVals, ValueRange extraVals,
61 SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
62 bool directOut) {
63 unsigned idx = 0;
64 for (auto type : types) {
65 // All "dense" data passes through unmodified.
66 if (!getSparseTensorEncoding(type)) {
67 toVals.push_back(fromVals[idx++]);
68 continue;
69 }
70 // Handle sparse data.
71 auto rtp = cast<RankedTensorType>(type);
72 const SparseTensorType stt(rtp);
73 SmallVector<Value> inputs;
74 SmallVector<Type> retTypes;
75 SmallVector<Type> cntTypes;
76 if (!isIn)
77 inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
78
79 // Collect the external representations of the pos/crd/val arrays.
82 Level lv, LevelType) {
83 if (kind == SparseTensorFieldKind::PosMemRef ||
84 kind == SparseTensorFieldKind::CrdMemRef ||
85 kind == SparseTensorFieldKind::ValMemRef) {
86 if (isIn) {
87 inputs.push_back(fromVals[idx++]);
88 } else if (directOut) {
89 Value mem;
90 if (kind == SparseTensorFieldKind::PosMemRef)
91 mem = sparse_tensor::ToPositionsOp::create(builder, loc, inputs[0],
92 lv);
93 else if (kind == SparseTensorFieldKind::CrdMemRef)
94 mem = sparse_tensor::ToCoordinatesOp::create(builder, loc,
95 inputs[0], lv);
96 else
97 mem = sparse_tensor::ToValuesOp::create(builder, loc, inputs[0]);
98 toVals.push_back(mem);
99 } else {
100 ShapedType rtp = cast<ShapedType>(t);
101 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
102 inputs.push_back(extraVals[extra++]);
103 retTypes.push_back(rtp);
104 cntTypes.push_back(builder.getIndexType());
105 }
106 }
107 return true;
108 });
109
110 if (isIn) {
111 // Assemble multiple inputs into a single sparse tensor.
112 auto a = sparse_tensor::AssembleOp::create(builder, loc, rtp, inputs);
113 toVals.push_back(a.getResult());
114 } else if (!directOut) {
115 // Disassemble a single sparse input into multiple outputs.
116 // Note that this includes the counters, which are dropped.
117 unsigned len = retTypes.size();
118 retTypes.append(cntTypes);
119 auto d =
120 sparse_tensor::DisassembleOp::create(builder, loc, retTypes, inputs);
121 for (unsigned i = 0; i < len; i++)
122 toVals.push_back(d.getResult(i));
123 }
124 }
125}
126
127//===----------------------------------------------------------------------===//
128// Rewriting rules.
129//===----------------------------------------------------------------------===//
130
131namespace {
132
133// A rewriting rules that converts public entry methods that use sparse tensors
134// as input parameters and/or output return values into wrapper methods that
135// [dis]assemble the individual tensors that constitute the actual storage used
136// externally into MLIR sparse tensors before calling the original method.
137//
138// In particular, each sparse tensor input
139//
140// void foo(..., t, ...) { }
141//
142// makes the original foo() internal and adds the following wrapper method
143//
144// void foo(..., t1..tn, ...) {
145// t = assemble t1..tn
146// _internal_foo(..., t, ...)
147// }
148//
149// and likewise, each output tensor
150//
151// ... T ... bar(...) { return ..., t, ...; }
152//
153// makes the original bar() internal and adds the following wrapper method
154//
155// ... T1..TN ... bar(..., t1'..tn') {
156// ..., t, ... = _internal_bar(...)
157// t1..tn = disassemble t, t1'..tn'
158// return ..., t1..tn, ...
159// }
160//
161// (with a direct-out variant without the disassemble).
162//
163struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
165
166 SparseFuncAssembler(MLIRContext *context, bool dO)
167 : OpRewritePattern(context), directOut(dO) {}
168
169 LogicalResult matchAndRewrite(func::FuncOp funcOp,
170 PatternRewriter &rewriter) const override {
171 // Only rewrite public entry methods.
172 if (funcOp.isPrivate())
173 return failure();
174
175 // Translate sparse tensor types to external types.
176 SmallVector<Type> inputTypes;
177 SmallVector<Type> outputTypes;
178 SmallVector<Type> extraTypes;
179 bool hasAnnotation = false;
180 convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
181 false);
182 convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
183 directOut);
184
185 // Only sparse inputs or outputs need a wrapper method.
186 if (!hasAnnotation)
187 return failure();
188
189 // Modify the original method into an internal, private method.
190 auto orgName = funcOp.getName();
191 std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
192 funcOp.setName(wrapper);
193 funcOp.setPrivate();
194
195 // Start the new public wrapper method with original name.
196 Location loc = funcOp.getLoc();
197 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
198 MLIRContext *context = modOp.getContext();
199 OpBuilder moduleBuilder(modOp.getBodyRegion());
200 unsigned extra = inputTypes.size();
201 inputTypes.append(extraTypes);
202 auto func = func::FuncOp::create(
203 moduleBuilder, loc, orgName,
204 FunctionType::get(context, inputTypes, outputTypes));
205 func.setPublic();
206
207 // Construct new wrapper method body.
208 OpBuilder::InsertionGuard insertionGuard(rewriter);
209 Block *body = func.addEntryBlock();
210 rewriter.setInsertionPointToStart(body);
211
212 // Convert inputs.
213 SmallVector<Value> inputs;
214 convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
215 ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
216
217 // Call the original, now private method. A subsequent inlining pass can
218 // determine whether cloning the method body in place is worthwhile.
219 auto org = SymbolRefAttr::get(context, wrapper);
220 auto call = func::CallOp::create(rewriter, loc, funcOp.getResultTypes(),
221 org, inputs);
222
223 // Convert outputs and return.
224 SmallVector<Value> outputs;
225 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
226 body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
227 func::ReturnOp::create(rewriter, loc, outputs);
228
229 // Finally, migrate a potential c-interface property.
230 if (funcOp->getAttrOfType<UnitAttr>(
231 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
232 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
233 UnitAttr::get(context));
234 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
235 }
236 return success();
237 }
238
239private:
240 const bool directOut;
241};
242
243} // namespace
244
245//===----------------------------------------------------------------------===//
246// Public method for populating conversion rules.
247//===----------------------------------------------------------------------===//
248
250 bool directOut) {
251 patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
252}
return success()
static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
BlockArgListType getArguments()
Definition Block.h:87
IndexType getIndexType()
Definition Builders.cpp:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238