15#include "llvm/Support/FormatVariadic.h"
28 for (
auto type : types) {
42 if (kind == SparseTensorFieldKind::PosMemRef ||
43 kind == SparseTensorFieldKind::CrdMemRef ||
44 kind == SparseTensorFieldKind::ValMemRef) {
45 auto rtp = cast<ShapedType>(t);
47 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
49 extraTypes->push_back(rtp);
64 for (
auto type : types) {
67 toVals.push_back(fromVals[idx++]);
71 auto rtp = cast<RankedTensorType>(type);
77 inputs.push_back(fromVals[idx++]);
83 if (kind == SparseTensorFieldKind::PosMemRef ||
84 kind == SparseTensorFieldKind::CrdMemRef ||
85 kind == SparseTensorFieldKind::ValMemRef) {
87 inputs.push_back(fromVals[idx++]);
88 }
else if (directOut) {
90 if (kind == SparseTensorFieldKind::PosMemRef)
91 mem = sparse_tensor::ToPositionsOp::create(builder, loc, inputs[0],
93 else if (kind == SparseTensorFieldKind::CrdMemRef)
94 mem = sparse_tensor::ToCoordinatesOp::create(builder, loc,
97 mem = sparse_tensor::ToValuesOp::create(builder, loc, inputs[0]);
98 toVals.push_back(mem);
100 ShapedType rtp = cast<ShapedType>(t);
101 rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
102 inputs.push_back(extraVals[extra++]);
103 retTypes.push_back(rtp);
112 auto a = sparse_tensor::AssembleOp::create(builder, loc, rtp, inputs);
113 toVals.push_back(a.getResult());
114 }
else if (!directOut) {
117 unsigned len = retTypes.size();
118 retTypes.append(cntTypes);
120 sparse_tensor::DisassembleOp::create(builder, loc, retTypes, inputs);
121 for (
unsigned i = 0; i < len; i++)
122 toVals.push_back(d.getResult(i));
166 SparseFuncAssembler(MLIRContext *context,
bool dO)
167 : OpRewritePattern(context), directOut(dO) {}
169 LogicalResult matchAndRewrite(func::FuncOp funcOp,
170 PatternRewriter &rewriter)
const override {
172 if (funcOp.isPrivate())
176 SmallVector<Type> inputTypes;
177 SmallVector<Type> outputTypes;
178 SmallVector<Type> extraTypes;
179 bool hasAnnotation =
false;
180 convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes,
nullptr,
182 convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
190 auto orgName = funcOp.getName();
191 std::string wrapper = llvm::formatv(
"_internal_{0}", orgName).str();
192 funcOp.setName(wrapper);
196 Location loc = funcOp.getLoc();
197 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
198 MLIRContext *context = modOp.getContext();
199 OpBuilder moduleBuilder(modOp.getBodyRegion());
200 unsigned extra = inputTypes.size();
201 inputTypes.append(extraTypes);
202 auto func = func::FuncOp::create(
203 moduleBuilder, loc, orgName,
204 FunctionType::get(context, inputTypes, outputTypes));
208 OpBuilder::InsertionGuard insertionGuard(rewriter);
209 Block *body = func.addEntryBlock();
213 SmallVector<Value> inputs;
219 auto org = SymbolRefAttr::get(context, wrapper);
220 auto call = func::CallOp::create(rewriter, loc, funcOp.getResultTypes(),
224 SmallVector<Value> outputs;
225 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
226 body->
getArguments(), outputs, extra,
false, directOut);
227 func::ReturnOp::create(rewriter, loc, outputs);
230 if (funcOp->getAttrOfType<UnitAttr>(
231 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
232 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
233 UnitAttr::get(context));
234 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
240 const bool directOut;
static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
BlockArgListType getArguments()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.