25 #define GEN_PASS_DEF_SPARSEASSEMBLER
26 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
28 #define GEN_PASS_DEF_SPARSIFICATIONPASS
29 #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
30 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
31 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
32 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
33 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
34 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
35 #define GEN_PASS_DEF_SPARSEVECTORIZATION
36 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
37 #define GEN_PASS_DEF_STAGESPARSEOPERATIONS
38 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
39 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
51 struct SparseAssembler :
public impl::SparseAssemblerBase<SparseAssembler> {
52 SparseAssembler() =
default;
53 SparseAssembler(
const SparseAssembler &pass) =
default;
54 SparseAssembler(
bool dO) { directOut = dO; }
56 void runOnOperation()
override {
64 struct SparseReinterpretMap
65 :
public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
66 SparseReinterpretMap() =
default;
67 SparseReinterpretMap(
const SparseReinterpretMap &pass) =
default;
68 SparseReinterpretMap(
const SparseReinterpretMapOptions &
options) {
70 loopOrderingStrategy =
options.loopOrderingStrategy;
73 void runOnOperation()
override {
81 struct PreSparsificationRewritePass
82 :
public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
83 PreSparsificationRewritePass() =
default;
84 PreSparsificationRewritePass(
const PreSparsificationRewritePass &pass) =
87 void runOnOperation()
override {
95 struct SparsificationPass
96 :
public impl::SparsificationPassBase<SparsificationPass> {
97 SparsificationPass() =
default;
98 SparsificationPass(
const SparsificationPass &pass) =
default;
100 parallelization =
options.parallelizationStrategy;
101 sparseEmitStrategy =
options.sparseEmitStrategy;
102 enableRuntimeLibrary =
options.enableRuntimeLibrary;
105 void runOnOperation()
override {
109 enableRuntimeLibrary);
113 scf::ForOp::getCanonicalizationPatterns(
patterns, ctx);
118 struct StageSparseOperationsPass
119 :
public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
120 StageSparseOperationsPass() =
default;
121 StageSparseOperationsPass(
const StageSparseOperationsPass &pass) =
default;
122 void runOnOperation()
override {
130 struct LowerSparseOpsToForeachPass
131 :
public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
132 LowerSparseOpsToForeachPass() =
default;
133 LowerSparseOpsToForeachPass(
const LowerSparseOpsToForeachPass &pass) =
135 LowerSparseOpsToForeachPass(
bool enableRT,
bool convert) {
136 enableRuntimeLibrary = enableRT;
137 enableConvert = convert;
140 void runOnOperation()
override {
149 struct LowerForeachToSCFPass
150 :
public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
151 LowerForeachToSCFPass() =
default;
152 LowerForeachToSCFPass(
const LowerForeachToSCFPass &pass) =
default;
154 void runOnOperation()
override {
162 struct LowerSparseIterationToSCFPass
163 :
public impl::LowerSparseIterationToSCFBase<
164 LowerSparseIterationToSCFPass> {
165 LowerSparseIterationToSCFPass() =
default;
166 LowerSparseIterationToSCFPass(
const LowerSparseIterationToSCFPass &) =
169 void runOnOperation()
override {
176 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
177 memref::MemRefDialect, scf::SCFDialect,
178 sparse_tensor::SparseTensorDialect>();
179 target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
181 target.addLegalOp<UnrealizedConversionCastOp>();
190 struct SparseTensorConversionPass
191 :
public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
192 SparseTensorConversionPass() =
default;
193 SparseTensorConversionPass(
const SparseTensorConversionPass &pass) =
default;
195 void runOnOperation()
override {
201 target.addIllegalDialect<SparseTensorDialect>();
205 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
208 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
211 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
212 return converter.
isLegal(op.getOperandTypes());
214 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
215 return converter.
isLegal(op.getOperandTypes());
217 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
218 return converter.
isLegal(op.getSource().getType()) &&
219 converter.
isLegal(op.getDest().getType());
221 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
222 [&](tensor::ExpandShapeOp op) {
223 return converter.
isLegal(op.getSrc().getType()) &&
224 converter.
isLegal(op.getResult().getType());
226 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
227 [&](tensor::CollapseShapeOp op) {
228 return converter.
isLegal(op.getSrc().getType()) &&
229 converter.
isLegal(op.getResult().getType());
231 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
232 [&](bufferization::AllocTensorOp op) {
233 return converter.
isLegal(op.getType());
235 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
236 [&](bufferization::DeallocTensorOp op) {
237 return converter.
isLegal(op.getTensor().getType());
241 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
242 linalg::YieldOp, tensor::ExtractOp,
243 tensor::FromElementsOp>();
244 target.addLegalDialect<
245 arith::ArithDialect, bufferization::BufferizationDialect,
246 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
249 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
261 struct SparseTensorCodegenPass
262 :
public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
263 SparseTensorCodegenPass() =
default;
264 SparseTensorCodegenPass(
const SparseTensorCodegenPass &pass) =
default;
265 SparseTensorCodegenPass(
bool createDeallocs,
bool enableInit) {
266 createSparseDeallocs = createDeallocs;
267 enableBufferInitialization = enableInit;
270 void runOnOperation()
override {
276 target.addIllegalDialect<SparseTensorDialect>();
277 target.addLegalOp<SortOp>();
278 target.addLegalOp<PushBackOp>();
280 target.addLegalOp<GetStorageSpecifierOp>();
281 target.addLegalOp<SetStorageSpecifierOp>();
282 target.addLegalOp<StorageSpecifierInitOp>();
284 target.addLegalOp<tensor::FromElementsOp>();
289 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
292 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
295 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
296 return converter.
isLegal(op.getOperandTypes());
298 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
299 [&](bufferization::AllocTensorOp op) {
300 return converter.
isLegal(op.getType());
302 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
303 [&](bufferization::DeallocTensorOp op) {
304 return converter.
isLegal(op.getTensor().getType());
308 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
309 target.addLegalDialect<
310 arith::ArithDialect, bufferization::BufferizationDialect,
311 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
312 target.addLegalOp<UnrealizedConversionCastOp>();
314 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
319 converter,
patterns, createSparseDeallocs, enableBufferInitialization);
326 struct SparseBufferRewritePass
327 :
public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
328 SparseBufferRewritePass() =
default;
329 SparseBufferRewritePass(
const SparseBufferRewritePass &pass) =
default;
330 SparseBufferRewritePass(
bool enableInit) {
331 enableBufferInitialization = enableInit;
334 void runOnOperation()
override {
342 struct SparseVectorizationPass
343 :
public impl::SparseVectorizationBase<SparseVectorizationPass> {
344 SparseVectorizationPass() =
default;
345 SparseVectorizationPass(
const SparseVectorizationPass &pass) =
default;
346 SparseVectorizationPass(
unsigned vl,
bool vla,
bool sidx32) {
348 enableVLAVectorization = vla;
349 enableSIMDIndex32 = sidx32;
352 void runOnOperation()
override {
353 if (vectorLength == 0)
354 return signalPassFailure();
358 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
364 struct SparseGPUCodegenPass
365 :
public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
366 SparseGPUCodegenPass() =
default;
367 SparseGPUCodegenPass(
const SparseGPUCodegenPass &pass) =
default;
368 SparseGPUCodegenPass(
unsigned nT,
bool enableRT) {
370 enableRuntimeLibrary = enableRT;
373 void runOnOperation()
override {
384 struct StorageSpecifierToLLVMPass
385 :
public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
386 StorageSpecifierToLLVMPass() =
default;
388 void runOnOperation()
override {
395 target.addIllegalDialect<SparseTensorDialect>();
396 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
399 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
402 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
403 return converter.
isLegal(op.getOperandTypes());
405 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
407 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
428 return std::make_unique<SparseAssembler>();
432 return std::make_unique<SparseReinterpretMap>();
435 std::unique_ptr<Pass>
437 SparseReinterpretMapOptions
options;
439 return std::make_unique<SparseReinterpretMap>(
options);
444 SparseReinterpretMapOptions
options;
446 options.loopOrderingStrategy = strategy;
447 return std::make_unique<SparseReinterpretMap>(
options);
451 return std::make_unique<PreSparsificationRewritePass>();
455 return std::make_unique<SparsificationPass>();
458 std::unique_ptr<Pass>
460 return std::make_unique<SparsificationPass>(
options);
464 return std::make_unique<StageSparseOperationsPass>();
468 return std::make_unique<LowerSparseOpsToForeachPass>();
471 std::unique_ptr<Pass>
473 return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
477 return std::make_unique<LowerForeachToSCFPass>();
481 return std::make_unique<LowerSparseIterationToSCFPass>();
485 return std::make_unique<SparseTensorConversionPass>();
489 return std::make_unique<SparseTensorCodegenPass>();
492 std::unique_ptr<Pass>
494 bool enableBufferInitialization) {
495 return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
496 enableBufferInitialization);
500 return std::make_unique<SparseBufferRewritePass>();
503 std::unique_ptr<Pass>
505 return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
509 return std::make_unique<SparseVectorizationPass>();
512 std::unique_ptr<Pass>
514 bool enableVLAVectorization,
515 bool enableSIMDIndex32) {
516 return std::make_unique<SparseVectorizationPass>(
517 vectorLength, enableVLAVectorization, enableSIMDIndex32);
521 return std::make_unique<SparseGPUCodegenPass>();
526 return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
530 return std::make_unique<StorageSpecifierToLLVMPass>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
Sparse tensor type converter into an actual buffer.
Sparse tensor type converter into an opaque pointer.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality constraints for the converted ops.
LoopOrderingStrategy
Defines a strategy for loop ordering during sparse code generation.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseVectorizationPass()
std::unique_ptr< Pass > createSparseAssembler()
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createLowerSparseOpsToForeachPass()
std::unique_ptr< Pass > createSparseTensorCodegenPass()
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::unique_ptr< Pass > createSparseGPUCodegenPass()
std::unique_ptr< Pass > createSparseReinterpretMapPass()
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
std::unique_ptr< Pass > createSparseTensorConversionPass()
std::unique_ptr< Pass > createSparseBufferRewritePass()
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been legalized by the conversion framework.
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
ReinterpretMapScope
Defines a scope for reinterpret map pass.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
const FrozenRewritePatternSet & patterns
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
std::unique_ptr< Pass > createStorageSpecifierToLLVMPass()
std::unique_ptr< Pass > createPreSparsificationRewritePass()
std::unique_ptr< Pass > createLowerForeachToSCFPass()
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
std::unique_ptr< Pass > createLowerSparseIterationToSCFPass()
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
std::unique_ptr< Pass > createStageSparseOperationsPass()
std::unique_ptr< Pass > createSparsificationPass()
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy=sparse_tensor::LoopOrderingStrategy::kDefault)
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Type converter for iter_space and iterator.
Options for the Sparsification pass.