17 #include "llvm/Support/FormatVariadic.h"
20 using namespace sparse_tensor;
29 for (
auto type : types) {
45 auto rtp = cast<ShapedType>(t);
49 extraTypes->push_back(rtp);
64 for (
auto type : types) {
67 toVals.push_back(fromVals[idx++]);
71 auto rtp = cast<RankedTensorType>(type);
77 inputs.push_back(fromVals[idx++]);
87 inputs.push_back(fromVals[idx++]);
88 }
else if (directOut) {
91 mem = builder.
create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
94 mem = builder.
create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
97 mem = builder.
create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
98 toVals.push_back(mem);
100 ShapedType rtp = cast<ShapedType>(t);
102 inputs.push_back(extraVals[extra++]);
103 retTypes.push_back(rtp);
112 auto a = builder.
create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
113 toVals.push_back(a.getResult());
114 }
else if (!directOut) {
117 unsigned len = retTypes.size();
118 retTypes.append(cntTypes);
120 builder.
create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
121 for (
unsigned i = 0; i < len; i++)
122 toVals.push_back(d.getResult(i));
172 if (funcOp.isPrivate())
179 convTypes(funcOp.getArgumentTypes(), inputTypes,
nullptr,
false);
180 convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
183 if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
184 outputTypes.size() == funcOp.getResultTypes().size())
188 auto orgName = funcOp.getName();
189 std::string wrapper = llvm::formatv(
"_internal_{0}", orgName).str();
190 funcOp.setName(wrapper);
195 ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
197 OpBuilder moduleBuilder(modOp.getBodyRegion());
198 unsigned extra = inputTypes.size();
199 inputTypes.append(extraTypes);
200 auto func = moduleBuilder.create<func::FuncOp>(
206 Block *body = func.addEntryBlock();
217 auto call = rewriter.
create<func::CallOp>(loc, funcOp.getResultTypes(), org,
222 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
223 body->
getArguments(), outputs, extra,
false, directOut);
224 rewriter.
create<func::ReturnOp>(loc, outputs);
227 if (funcOp->getAttrOfType<UnitAttr>(
228 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
229 func->
setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
231 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
237 const bool directOut;
248 patterns.
add<SparseFuncAssembler>(patterns.
getContext(), directOut);
static void convTypes(TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)
static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)
Block represents an ordered list of Operations.
BlockArgListType getArguments()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A wrapper around RankedTensorType, which has three goals:
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.