#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
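// Each GEN_PASS_DEF_* macro above selects one pass definition from the
// tablegen-generated Passes.h.inc. As a rough sketch (the exact shape is
// generated and may differ), every selected pass receives a CRTP base class
// in namespace `impl` whose data members mirror the options declared in
// Passes.td, e.g.:
//
//   template <typename DerivedT>
//   class SparsificationPassBase : public ::mlir::OperationPass<> {
//   protected:
//     ::mlir::Pass::Option<bool> enableRuntimeLibrary{
//         *this, "enable-runtime-library",
//         ::llvm::cl::desc("Enable runtime library for manipulating sparse "
//                          "tensors"),
//         ::llvm::cl::init(true)};
//     // ...
//   };
//
// This is why the pass structs below can assign to fields such as
// `enableRuntimeLibrary` directly in their constructors.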
namespace {

struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }
  // ...
};
struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    // ...
    loopOrderingStrategy = options.loopOrderingStrategy;
  }

  void runOnOperation() override {
    // ...
  }
};
struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    // ...
  }
};
struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};
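// Usage sketch (not part of this file): invoking the sparsifier from C++,
// assuming a ModuleOp `module` is already loaded. The option values are the
// defaults, shown for illustration only.
//
//   PassManager pm(module->getContext());
//   pm.addPass(createSparsificationPass(SparsificationOptions()));
//   if (failed(pm.run(module)))
//     llvm::errs() << "sparsification failed\n";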
struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  // ...
};
struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    // ...
  }
};
struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    // ...
  }
};
struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseIterationTypeConverter converter;
    ConversionTarget target(*ctx);
    // The actual conversion.
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, scf::SCFDialect,
                           sparse_tensor::SparseTensorDialect>();
    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
                        IterateOp>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // ...
  }
};
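// The elided tail of runOnOperation() above follows the standard dialect
// conversion idiom; a sketch of the likely completion, given the converter,
// target, and patterns set up above:
//
//   populateLowerSparseIterationToSCFPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();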
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    // ...
  }
};
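// Illustration (not from this file) of how the dynamic legality rules above
// cooperate with the type converter: with a converter that maps sparse
// tensor types to an opaque pointer, a function like
//
//   func.func @sum(%arg0: tensor<?xf64, #SparseVector>) -> f64
//
// remains illegal until signature conversion rewrites it to, roughly,
//
//   func.func @sum(%arg0: !llvm.ptr) -> f64
//
// at which point isSignatureLegal() succeeds and the conversion driver
// accepts the op.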
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives the sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp generates a dense tensor, which stays.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // bufferization operations as legal output of the rewriting provided
    // that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    // ...
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    // ...
  }
};
struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }
  // ...
};
struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};
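// Usage sketch (not part of this file): the vector length must be nonzero
// or the pass fails immediately, as coded above. For example:
//
//   pm.addPass(createSparseVectorizationPass(
//       /*vectorLength=*/16,
//       /*enableVLAVectorization=*/false,
//       /*enableSIMDIndex32=*/false));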
struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    // ...
  }
};
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    // ...
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
    // ...
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    // ...
  }
};

} // namespace
//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
    ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  options.loopOrderingStrategy = strategy;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseGPUCodegenPass(unsigned numThreads, bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
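// End-to-end sketch (not part of this file): one plausible ordering of these
// factory methods when lowering annotated linalg code all the way down; the
// authoritative ordering lives in the SparseTensorPipelines library and may
// differ in detail.
//
//   PassManager pm(ctx);
//   pm.addPass(createPreSparsificationRewritePass());
//   pm.addPass(createSparsificationPass());
//   pm.addPass(createLowerSparseOpsToForeachPass());
//   pm.addPass(createLowerForeachToSCFPass());
//   pm.addPass(createSparseTensorConversionPass()); // or the codegen path
//   pm.addPass(createStorageSpecifierToLLVMPass());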