15#include "llvm/Support/FormatVariadic.h"
28 for (
auto type : types) {
44 if (kind == SparseTensorFieldKind::PosMemRef ||
45 kind == SparseTensorFieldKind::CrdMemRef ||
46 kind == SparseTensorFieldKind::ValMemRef) {
47 auto rtp = cast<ShapedType>(t);
49 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
51 extraTypes->push_back(rtp);
66 for (
auto type : types) {
71 toVals.push_back(fromVals[idx++]);
75 auto rtp = cast<RankedTensorType>(type);
81 inputs.push_back(fromVals[idx++]);
87 if (kind == SparseTensorFieldKind::PosMemRef ||
88 kind == SparseTensorFieldKind::CrdMemRef ||
89 kind == SparseTensorFieldKind::ValMemRef) {
91 inputs.push_back(fromVals[idx++]);
92 }
else if (directOut) {
94 if (kind == SparseTensorFieldKind::PosMemRef)
95 mem = sparse_tensor::ToPositionsOp::create(builder, loc, inputs[0],
97 else if (kind == SparseTensorFieldKind::CrdMemRef)
98 mem = sparse_tensor::ToCoordinatesOp::create(builder, loc,
101 mem = sparse_tensor::ToValuesOp::create(builder, loc, inputs[0]);
102 toVals.push_back(mem);
104 ShapedType rtp = cast<ShapedType>(t);
105 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
106 inputs.push_back(extraVals[extra++]);
107 retTypes.push_back(rtp);
116 auto a = sparse_tensor::AssembleOp::create(builder, loc, rtp, inputs);
117 toVals.push_back(a.getResult());
118 }
else if (!directOut) {
121 unsigned len = retTypes.size();
122 retTypes.append(cntTypes);
124 sparse_tensor::DisassembleOp::create(builder, loc, retTypes, inputs);
125 for (
unsigned i = 0; i < len; i++)
126 toVals.push_back(d.getResult(i));
/// Constructs the rewrite pattern for the given context. The `dO` flag is
/// stored in the `directOut` member, which the rewrite later passes on to
/// the type/value conversion helpers.
170 SparseFuncAssembler(MLIRContext *context,
bool dO)
171 : OpRewritePattern(context), directOut(dO) {}
173 LogicalResult matchAndRewrite(func::FuncOp funcOp,
174 PatternRewriter &rewriter)
const override {
176 if (funcOp.isPrivate())
180 SmallVector<Type> inputTypes;
181 SmallVector<Type> outputTypes;
182 SmallVector<Type> extraTypes;
183 bool hasAnnotation =
false;
184 convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes,
nullptr,
186 convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
194 auto orgName = funcOp.getName();
195 std::string wrapper = llvm::formatv(
"_internal_{0}", orgName).str();
197 funcOp.setName(wrapper);
202 Location loc = funcOp.getLoc();
203 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
204 MLIRContext *context = modOp.getContext();
205 OpBuilder moduleBuilder(modOp.getBodyRegion());
206 unsigned extra = inputTypes.size();
207 inputTypes.append(extraTypes);
208 auto func = func::FuncOp::create(
209 moduleBuilder, loc, orgName,
210 FunctionType::get(context, inputTypes, outputTypes));
214 OpBuilder::InsertionGuard insertionGuard(rewriter);
215 Block *body = func.addEntryBlock();
219 SmallVector<Value> inputs;
225 auto org = SymbolRefAttr::get(context, wrapper);
226 auto call = func::CallOp::create(rewriter, loc, funcOp.getResultTypes(),
230 SmallVector<Value> outputs;
231 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
232 body->
getArguments(), outputs, extra,
false, directOut);
233 func::ReturnOp::create(rewriter, loc, outputs);
236 if (funcOp->getAttrOfType<UnitAttr>(
237 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
238 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
239 UnitAttr::get(context));
241 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
248 const bool directOut;
259 patterns.
add<SparseFuncAssembler>(patterns.
getContext(), directOut);
static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
BlockArgListType getArguments()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
The sparse tensor storage...
Include the generated interface declarations.
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
This enum defines all the sparse representations supportable by the SparseTensor dialect.