MLIR  20.0.0git
SparseAssembler.cpp
Go to the documentation of this file.
1 //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
10 
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 using namespace sparse_tensor;
21 
22 //===----------------------------------------------------------------------===//
23 // Helper methods.
24 //===----------------------------------------------------------------------===//
25 
26 // Convert type range to new types range, with sparse tensors externalized.
//
// hasAnnotation  Set to true when any type in `types` has a sparse tensor
//                encoding. It is never reset to false here, so a caller can
//                accumulate it over several calls.
// types          The types to convert.
// convTypes      Receives, in order, every type without a sparse encoding
//                unchanged and, for each sparse type, its externalized field
//                types.
// extraTypes     When non-null and `directOut` is false, also receives each
//                converted field type. The caller appends these as extra
//                wrapper arguments.
// directOut      When true, each field type is kept as reported (cast to
//                ShapedType) instead of being rebuilt as a RankedTensorType
//                with the same shape and element type.
27 static void convTypes(bool &hasAnnotation, TypeRange types,
29  SmallVectorImpl<Type> *extraTypes, bool directOut) {
30  for (auto type : types) {
31  // All "dense" data passes through unmodified.
32  if (!getSparseTensorEncoding(type)) {
33  convTypes.push_back(type);
34  continue;
35  }
36  hasAnnotation = true;
37 
38  // Convert the external representations of the pos/crd/val arrays.
39  const SparseTensorType stt(cast<RankedTensorType>(type));
  // NOTE(review): original lines 40, 42 and 44-46 are missing from this
  // listing. Judging by the comment above and the extra closing brace, they
  // presumably call foreachFieldAndTypeInSparseTensor and restrict the lambda
  // to the pos/crd/val fields. Confirm against the real file.
41  stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
43  Level, LevelType) {
47  auto rtp = cast<ShapedType>(t);
48  if (!directOut) {
  // Externalize the field as a plain ranked tensor of the same shape and
  // element type.
49  rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
50  if (extraTypes)
51  extraTypes->push_back(rtp);
52  }
53  convTypes.push_back(rtp);
54  }
  // Returning true continues the field traversal.
55  return true;
56  });
57  }
58 }
59 
60 // Convert input and output values to [dis]assemble ops for sparse tensors.
//
// For each type in `types`:
//  - A type without a sparse encoding forwards the next value of `fromVals`
//    to `toVals` unchanged.
//  - For a sparse type with `isIn`, the field values are taken from
//    `fromVals` and combined by a sparse_tensor::AssembleOp of that type. The
//    op's result is appended to `toVals`.
//  - For a sparse type without `isIn`, the next value of `fromVals` is the
//    sparse tensor itself.
//    - With `directOut`, its fields are read back with ToPositionsOp,
//      ToCoordinatesOp and ToValuesOp, and the results are appended to
//      `toVals`.
//    - Otherwise a sparse_tensor::DisassembleOp is built whose extra operands
//      come from `extraVals`, starting at index `extra`. Only its non-counter
//      results are appended to `toVals`.
// `extra` is taken by value, so the caller's index is not advanced.
61 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
62  ValueRange fromVals, ValueRange extraVals,
63  SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
64  bool directOut) {
  // Position of the next unconsumed value in `fromVals`.
65  unsigned idx = 0;
66  for (auto type : types) {
67  // All "dense" data passes through unmodified.
68  if (!getSparseTensorEncoding(type)) {
69  toVals.push_back(fromVals[idx++]);
70  continue;
71  }
72  // Handle sparse data.
73  auto rtp = cast<RankedTensorType>(type);
74  const SparseTensorType stt(rtp);
75  SmallVector<Value> inputs;
76  SmallVector<Type> retTypes;
77  SmallVector<Type> cntTypes;
78  if (!isIn)
79  inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
80 
81  // Collect the external representations of the pos/crd/val arrays.
  // NOTE(review): original lines 82-83, 85-87 and 92 are missing from this
  // listing. They presumably open a foreachFieldAndTypeInSparseTensor lambda
  // (binding `t`, `kind` and `lv`), restrict it to the pos/crd/val fields,
  // and test for SparseTensorFieldKind::PosMemRef before line 93. Confirm.
84  Level lv, LevelType) {
88  if (isIn) {
89  inputs.push_back(fromVals[idx++]);
90  } else if (directOut) {
91  Value mem;
93  mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
94  lv);
95  else if (kind == SparseTensorFieldKind::CrdMemRef)
96  mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
97  lv);
98  else
99  mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
100  toVals.push_back(mem);
101  } else {
102  ShapedType rtp = cast<ShapedType>(t);
103  rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
  // The caller-provided output buffer for this field.
104  inputs.push_back(extraVals[extra++]);
105  retTypes.push_back(rtp);
  // One index-typed counter result per field.
106  cntTypes.push_back(builder.getIndexType());
107  }
108  }
109  return true;
110  });
111 
112  if (isIn) {
113  // Assemble multiple inputs into a single sparse tensor.
114  auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
115  toVals.push_back(a.getResult());
116  } else if (!directOut) {
117  // Disassemble a single sparse input into multiple outputs.
118  // Note that this includes the counters, which are dropped.
119  unsigned len = retTypes.size();
120  retTypes.append(cntTypes);
121  auto d =
122  builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
  // Forward only the first `len` (buffer) results; the trailing counter
  // results are not used.
123  for (unsigned i = 0; i < len; i++)
124  toVals.push_back(d.getResult(i));
125  }
126  }
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // Rewriting rules.
131 //===----------------------------------------------------------------------===//
132 
133 namespace {
134 
135 // A rewriting rule that converts public entry methods that use sparse tensors
136 // as input parameters and/or output return values into wrapper methods that
137 // [dis]assemble the individual tensors that constitute the actual storage used
138 // externally into MLIR sparse tensors before calling the original method.
139 //
140 // In particular, each sparse tensor input
141 //
142 // void foo(..., t, ...) { }
143 //
144 // makes the original foo() internal and adds the following wrapper method
145 //
146 // void foo(..., t1..tn, ...) {
147 // t = assemble t1..tn
148 // _internal_foo(..., t, ...)
149 // }
150 //
151 // and likewise, each output tensor
152 //
153 // ... T ... bar(...) { return ..., t, ...; }
154 //
155 // makes the original bar() internal and adds the following wrapper method
156 //
157 // ... T1..TN ... bar(..., t1'..tn') {
158 // ..., t, ... = _internal_bar(...)
159 // t1..tn = disassemble t, t1'..tn'
160 // return ..., t1..tn, ...
161 // }
162 //
163 // (with a direct-out variant without the disassemble).
164 //
165 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
167 
  // `dO` selects the direct-out variant for sparse results.
168  SparseFuncAssembler(MLIRContext *context, bool dO)
169  : OpRewritePattern(context), directOut(dO) {}
170 
  // Returns failure for private functions and for functions with no
  // sparse-encoded argument or result types. Otherwise it:
  //  - renames the original function to "_internal_<name>" and makes it
  //    private;
  //  - adds a public wrapper under the original name that converts the
  //    arguments, calls the original, and converts the results.
  // NOTE(review): the rename, the visibility change and the attribute removal
  // mutate `funcOp` directly, and the wrapper is created with `moduleBuilder`
  // rather than `rewriter`. The pattern driver is therefore not notified of
  // these changes. Confirm this is acceptable for how the pattern is applied.
171  LogicalResult matchAndRewrite(func::FuncOp funcOp,
172  PatternRewriter &rewriter) const override {
173  // Only rewrite public entry methods.
174  if (funcOp.isPrivate())
175  return failure();
176 
177  // Translate sparse tensor types to external types.
178  SmallVector<Type> inputTypes;
179  SmallVector<Type> outputTypes;
180  SmallVector<Type> extraTypes;
181  bool hasAnnotation = false;
182  convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
183  false);
184  convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
185  directOut);
186 
187  // Only sparse inputs or outputs need a wrapper method.
188  if (!hasAnnotation)
189  return failure();
190 
191  // Modify the original method into an internal, private method.
192  auto orgName = funcOp.getName();
193  std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
194  funcOp.setName(wrapper);
195  funcOp.setPrivate();
196 
197  // Start the new public wrapper method with original name.
198  Location loc = funcOp.getLoc();
199  ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
200  MLIRContext *context = modOp.getContext();
201  OpBuilder moduleBuilder(modOp.getBodyRegion());
  // The extra output-buffer arguments follow the converted inputs. `extra`
  // records the index of the first one.
202  unsigned extra = inputTypes.size();
203  inputTypes.append(extraTypes);
204  auto func = moduleBuilder.create<func::FuncOp>(
205  loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
206  func.setPublic();
207 
208  // Construct new wrapper method body.
209  OpBuilder::InsertionGuard insertionGuard(rewriter);
210  Block *body = func.addEntryBlock();
211  rewriter.setInsertionPointToStart(body);
212 
213  // Convert inputs.
214  SmallVector<Value> inputs;
215  convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
216  ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
217 
218  // Call the original, now private method. A subsequent inlining pass can
219  // determine whether cloning the method body in place is worthwhile.
220  auto org = SymbolRefAttr::get(context, wrapper);
221  auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
222  inputs);
223 
224  // Convert outputs and return.
225  SmallVector<Value> outputs;
226  convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
227  body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
228  rewriter.create<func::ReturnOp>(loc, outputs);
229 
230  // Finally, migrate a potential c-interface property.
231  if (funcOp->getAttrOfType<UnitAttr>(
232  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
233  func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
234  UnitAttr::get(context));
235  funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
236  }
237  return success();
238  }
239 
240 private:
  // When true, sparse results are returned as their field buffers (via
  // ToPositionsOp/ToCoordinatesOp/ToValuesOp) instead of through a
  // DisassembleOp into caller-provided buffers.
241  const bool directOut;
242 };
243 
244 } // namespace
245 
246 //===----------------------------------------------------------------------===//
247 // Public method for populating conversion rules.
248 //===----------------------------------------------------------------------===//
249 
251  bool directOut) {
  // NOTE(review): the first signature line (original line 250) is missing
  // from this listing. Per the declaration listed on this page it is
  // presumably `void populateSparseAssembler(RewritePatternSet &patterns,`.
  // Registers the SparseFuncAssembler pattern, forwarding `directOut`.
252  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
253 }
static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
IndexType getIndexType()
Definition: Builders.cpp:95
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
Definition: Operation.h:577
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
The sparse tensor storage...
Include the generated interface declarations.
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238