MLIR  22.0.0git
SparseAssembler.cpp
Go to the documentation of this file.
1 //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
10 
15 #include "llvm/Support/FormatVariadic.h"
16 
17 using namespace mlir;
18 using namespace sparse_tensor;
19 
20 //===----------------------------------------------------------------------===//
21 // Helper methods.
22 //===----------------------------------------------------------------------===//
23 
24 // Convert type range to new types range, with sparse tensors externalized.
25 static void convTypes(bool &hasAnnotation, TypeRange types,
27  SmallVectorImpl<Type> *extraTypes, bool directOut) {
28  for (auto type : types) {
29  // All "dense" data passes through unmodified.
30  if (!getSparseTensorEncoding(type)) {
31  convTypes.push_back(type);
32  continue;
33  }
34  hasAnnotation = true;
35 
36  // Convert the external representations of the pos/crd/val arrays.
37  const SparseTensorType stt(cast<RankedTensorType>(type));
39  stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
41  Level, LevelType) {
45  auto rtp = cast<ShapedType>(t);
46  if (!directOut) {
47  rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
48  if (extraTypes)
49  extraTypes->push_back(rtp);
50  }
51  convTypes.push_back(rtp);
52  }
53  return true;
54  });
55  }
56 }
57 
58 // Convert input and output values to [dis]assemble ops for sparse tensors.
59 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
60  ValueRange fromVals, ValueRange extraVals,
61  SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
62  bool directOut) {
63  unsigned idx = 0;
64  for (auto type : types) {
65  // All "dense" data passes through unmodified.
66  if (!getSparseTensorEncoding(type)) {
67  toVals.push_back(fromVals[idx++]);
68  continue;
69  }
70  // Handle sparse data.
71  auto rtp = cast<RankedTensorType>(type);
72  const SparseTensorType stt(rtp);
73  SmallVector<Value> inputs;
74  SmallVector<Type> retTypes;
75  SmallVector<Type> cntTypes;
76  if (!isIn)
77  inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
78 
79  // Collect the external representations of the pos/crd/val arrays.
82  Level lv, LevelType) {
86  if (isIn) {
87  inputs.push_back(fromVals[idx++]);
88  } else if (directOut) {
89  Value mem;
91  mem = sparse_tensor::ToPositionsOp::create(builder, loc, inputs[0],
92  lv);
94  mem = sparse_tensor::ToCoordinatesOp::create(builder, loc,
95  inputs[0], lv);
96  else
97  mem = sparse_tensor::ToValuesOp::create(builder, loc, inputs[0]);
98  toVals.push_back(mem);
99  } else {
100  ShapedType rtp = cast<ShapedType>(t);
101  rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
102  inputs.push_back(extraVals[extra++]);
103  retTypes.push_back(rtp);
104  cntTypes.push_back(builder.getIndexType());
105  }
106  }
107  return true;
108  });
109 
110  if (isIn) {
111  // Assemble multiple inputs into a single sparse tensor.
112  auto a = sparse_tensor::AssembleOp::create(builder, loc, rtp, inputs);
113  toVals.push_back(a.getResult());
114  } else if (!directOut) {
115  // Disassemble a single sparse input into multiple outputs.
116  // Note that this includes the counters, which are dropped.
117  unsigned len = retTypes.size();
118  retTypes.append(cntTypes);
119  auto d =
120  sparse_tensor::DisassembleOp::create(builder, loc, retTypes, inputs);
121  for (unsigned i = 0; i < len; i++)
122  toVals.push_back(d.getResult(i));
123  }
124  }
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // Rewriting rules.
129 //===----------------------------------------------------------------------===//
130 
131 namespace {
132 
133 // A rewriting rules that converts public entry methods that use sparse tensors
134 // as input parameters and/or output return values into wrapper methods that
135 // [dis]assemble the individual tensors that constitute the actual storage used
136 // externally into MLIR sparse tensors before calling the original method.
137 //
138 // In particular, each sparse tensor input
139 //
140 // void foo(..., t, ...) { }
141 //
142 // makes the original foo() internal and adds the following wrapper method
143 //
144 // void foo(..., t1..tn, ...) {
145 // t = assemble t1..tn
146 // _internal_foo(..., t, ...)
147 // }
148 //
149 // and likewise, each output tensor
150 //
151 // ... T ... bar(...) { return ..., t, ...; }
152 //
153 // makes the original bar() internal and adds the following wrapper method
154 //
155 // ... T1..TN ... bar(..., t1'..tn') {
156 // ..., t, ... = _internal_bar(...)
157 // t1..tn = disassemble t, t1'..tn'
158 // return ..., t1..tn, ...
159 // }
160 //
161 // (with a direct-out variant without the disassemble).
162 //
163 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
165 
166  SparseFuncAssembler(MLIRContext *context, bool dO)
167  : OpRewritePattern(context), directOut(dO) {}
168 
169  LogicalResult matchAndRewrite(func::FuncOp funcOp,
170  PatternRewriter &rewriter) const override {
171  // Only rewrite public entry methods.
172  if (funcOp.isPrivate())
173  return failure();
174 
175  // Translate sparse tensor types to external types.
176  SmallVector<Type> inputTypes;
177  SmallVector<Type> outputTypes;
178  SmallVector<Type> extraTypes;
179  bool hasAnnotation = false;
180  convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
181  false);
182  convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
183  directOut);
184 
185  // Only sparse inputs or outputs need a wrapper method.
186  if (!hasAnnotation)
187  return failure();
188 
189  // Modify the original method into an internal, private method.
190  auto orgName = funcOp.getName();
191  std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
192  funcOp.setName(wrapper);
193  funcOp.setPrivate();
194 
195  // Start the new public wrapper method with original name.
196  Location loc = funcOp.getLoc();
197  ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
198  MLIRContext *context = modOp.getContext();
199  OpBuilder moduleBuilder(modOp.getBodyRegion());
200  unsigned extra = inputTypes.size();
201  inputTypes.append(extraTypes);
202  auto func = func::FuncOp::create(
203  moduleBuilder, loc, orgName,
204  FunctionType::get(context, inputTypes, outputTypes));
205  func.setPublic();
206 
207  // Construct new wrapper method body.
208  OpBuilder::InsertionGuard insertionGuard(rewriter);
209  Block *body = func.addEntryBlock();
210  rewriter.setInsertionPointToStart(body);
211 
212  // Convert inputs.
213  SmallVector<Value> inputs;
214  convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
215  ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
216 
217  // Call the original, now private method. A subsequent inlining pass can
218  // determine whether cloning the method body in place is worthwhile.
219  auto org = SymbolRefAttr::get(context, wrapper);
220  auto call = func::CallOp::create(rewriter, loc, funcOp.getResultTypes(),
221  org, inputs);
222 
223  // Convert outputs and return.
224  SmallVector<Value> outputs;
225  convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
226  body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
227  func::ReturnOp::create(rewriter, loc, outputs);
228 
229  // Finally, migrate a potential c-interface property.
230  if (funcOp->getAttrOfType<UnitAttr>(
231  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
232  func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
233  UnitAttr::get(context));
234  funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
235  }
236  return success();
237  }
238 
239 private:
240  const bool directOut;
241 };
242 
243 } // namespace
244 
245 //===----------------------------------------------------------------------===//
246 // Public method for populating conversion rules.
247 //===----------------------------------------------------------------------===//
248 
250  bool directOut) {
251  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
252 }
union mlir::linalg::@1227::ArityGroupAndKind::Kind kind
static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
IndexType getIndexType()
Definition: Builders.cpp:50
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238