MLIR  19.0.0git
SparseAssembler.cpp
Go to the documentation of this file.
1 //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
10 
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 using namespace sparse_tensor;
21 
22 //===----------------------------------------------------------------------===//
23 // Helper methods.
24 //===----------------------------------------------------------------------===//
25 
26 // Convert type range to new types range, with sparse tensors externalized.
28  SmallVectorImpl<Type> *extraTypes, bool directOut) {
29  for (auto type : types) {
30  // All "dense" data passes through unmodified.
31  if (!getSparseTensorEncoding(type)) {
32  convTypes.push_back(type);
33  continue;
34  }
35 
36  // Convert the external representations of the pos/crd/val arrays.
37  const SparseTensorType stt(cast<RankedTensorType>(type));
39  stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
41  Level, LevelType) {
45  auto rtp = cast<ShapedType>(t);
46  if (!directOut) {
47  rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
48  if (extraTypes)
49  extraTypes->push_back(rtp);
50  }
51  convTypes.push_back(rtp);
52  }
53  return true;
54  });
55  }
56 }
57 
58 // Convert input and output values to [dis]assemble ops for sparse tensors.
59 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
60  ValueRange fromVals, ValueRange extraVals,
61  SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
62  bool directOut) {
63  unsigned idx = 0;
64  for (auto type : types) {
65  // All "dense" data passes through unmodified.
66  if (!getSparseTensorEncoding(type)) {
67  toVals.push_back(fromVals[idx++]);
68  continue;
69  }
70  // Handle sparse data.
71  auto rtp = cast<RankedTensorType>(type);
72  const SparseTensorType stt(rtp);
73  SmallVector<Value> inputs;
74  SmallVector<Type> retTypes;
75  SmallVector<Type> cntTypes;
76  if (!isIn)
77  inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
78 
79  // Collect the external representations of the pos/crd/val arrays.
82  Level lv, LevelType) {
86  if (isIn) {
87  inputs.push_back(fromVals[idx++]);
88  } else if (directOut) {
89  Value mem;
91  mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
92  lv);
93  else if (kind == SparseTensorFieldKind::CrdMemRef)
94  mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
95  lv);
96  else
97  mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
98  toVals.push_back(mem);
99  } else {
100  ShapedType rtp = cast<ShapedType>(t);
101  rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
102  inputs.push_back(extraVals[extra++]);
103  retTypes.push_back(rtp);
104  cntTypes.push_back(builder.getIndexType());
105  }
106  }
107  return true;
108  });
109 
110  if (isIn) {
111  // Assemble multiple inputs into a single sparse tensor.
112  auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
113  toVals.push_back(a.getResult());
114  } else if (!directOut) {
115  // Disassemble a single sparse input into multiple outputs.
116  // Note that this includes the counters, which are dropped.
117  unsigned len = retTypes.size();
118  retTypes.append(cntTypes);
119  auto d =
120  builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
121  for (unsigned i = 0; i < len; i++)
122  toVals.push_back(d.getResult(i));
123  }
124  }
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // Rewriting rules.
129 //===----------------------------------------------------------------------===//
130 
131 namespace {
132 
133 // A rewriting rules that converts public entry methods that use sparse tensors
134 // as input parameters and/or output return values into wrapper methods that
135 // [dis]assemble the individual tensors that constitute the actual storage used
136 // externally into MLIR sparse tensors before calling the original method.
137 //
138 // In particular, each sparse tensor input
139 //
140 // void foo(..., t, ...) { }
141 //
142 // makes the original foo() internal and adds the following wrapper method
143 //
144 // void foo(..., t1..tn, ...) {
145 // t = assemble t1..tn
146 // _internal_foo(..., t, ...)
147 // }
148 //
149 // and likewise, each output tensor
150 //
151 // ... T ... bar(...) { return ..., t, ...; }
152 //
153 // makes the original bar() internal and adds the following wrapper method
154 //
155 // ... T1..TN ... bar(..., t1'..tn') {
156 // ..., t, ... = _internal_bar(...)
157 // t1..tn = disassemble t, t1'..tn'
158 // return ..., t1..tn, ...
159 // }
160 //
161 // (with a direct-out variant without the disassemble).
162 //
163 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
165 
166  SparseFuncAssembler(MLIRContext *context, bool dO)
167  : OpRewritePattern(context), directOut(dO) {}
168 
169  LogicalResult matchAndRewrite(func::FuncOp funcOp,
170  PatternRewriter &rewriter) const override {
171  // Only rewrite public entry methods.
172  if (funcOp.isPrivate())
173  return failure();
174 
175  // Translate sparse tensor types to external types.
176  SmallVector<Type> inputTypes;
177  SmallVector<Type> outputTypes;
178  SmallVector<Type> extraTypes;
179  convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
180  convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
181 
182  // Only sparse inputs or outputs need a wrapper method.
183  if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
184  outputTypes.size() == funcOp.getResultTypes().size())
185  return failure();
186 
187  // Modify the original method into an internal, private method.
188  auto orgName = funcOp.getName();
189  std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
190  funcOp.setName(wrapper);
191  funcOp.setPrivate();
192 
193  // Start the new public wrapper method with original name.
194  Location loc = funcOp.getLoc();
195  ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
196  MLIRContext *context = modOp.getContext();
197  OpBuilder moduleBuilder(modOp.getBodyRegion());
198  unsigned extra = inputTypes.size();
199  inputTypes.append(extraTypes);
200  auto func = moduleBuilder.create<func::FuncOp>(
201  loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
202  func.setPublic();
203 
204  // Construct new wrapper method body.
205  OpBuilder::InsertionGuard insertionGuard(rewriter);
206  Block *body = func.addEntryBlock();
207  rewriter.setInsertionPointToStart(body);
208 
209  // Convert inputs.
210  SmallVector<Value> inputs;
211  convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
212  ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
213 
214  // Call the original, now private method. A subsequent inlining pass can
215  // determine whether cloning the method body in place is worthwhile.
216  auto org = SymbolRefAttr::get(context, wrapper);
217  auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
218  inputs);
219 
220  // Convert outputs and return.
221  SmallVector<Value> outputs;
222  convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
223  body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
224  rewriter.create<func::ReturnOp>(loc, outputs);
225 
226  // Finally, migrate a potential c-interface property.
227  if (funcOp->getAttrOfType<UnitAttr>(
228  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
229  func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
230  UnitAttr::get(context));
231  funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
232  }
233  return success();
234  }
235 
236 private:
237  const bool directOut;
238 };
239 
240 } // namespace
241 
242 //===----------------------------------------------------------------------===//
243 // Public method for populating conversion rules.
244 //===----------------------------------------------------------------------===//
245 
247  bool directOut) {
248  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
249 }
static void convTypes(TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:84
IndexType getIndexType()
Definition: Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
The sparse tensor storag...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238