25#define GEN_PASS_DEF_SPARSEASSEMBLER
26#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
28#define GEN_PASS_DEF_SPARSIFICATIONPASS
29#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
30#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
31#define GEN_PASS_DEF_LOWERFOREACHTOSCF
32#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
33#define GEN_PASS_DEF_SPARSETENSORCODEGEN
34#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
35#define GEN_PASS_DEF_SPARSEVECTORIZATION
36#define GEN_PASS_DEF_SPARSEGPUCODEGEN
37#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
38#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
39#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
51struct SparseAssembler :
public impl::SparseAssemblerBase<SparseAssembler> {
52 SparseAssembler() =
default;
53 SparseAssembler(
const SparseAssembler &pass) =
default;
54 SparseAssembler(
bool dO) { directOut = dO; }
56 void runOnOperation()
override {
64struct SparseReinterpretMap
65 :
public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
66 SparseReinterpretMap() =
default;
67 SparseReinterpretMap(
const SparseReinterpretMap &pass) =
default;
68 SparseReinterpretMap(
const SparseReinterpretMapOptions &
options) {
70 loopOrderingStrategy =
options.loopOrderingStrategy;
73 void runOnOperation()
override {
81struct PreSparsificationRewritePass
82 :
public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
83 PreSparsificationRewritePass() =
default;
84 PreSparsificationRewritePass(
const PreSparsificationRewritePass &pass) =
87 void runOnOperation()
override {
95struct SparsificationPass
96 :
public impl::SparsificationPassBase<SparsificationPass> {
97 SparsificationPass() =
default;
98 SparsificationPass(
const SparsificationPass &pass) =
default;
99 SparsificationPass(
const SparsificationOptions &
options) {
100 parallelization =
options.parallelizationStrategy;
101 sparseEmitStrategy =
options.sparseEmitStrategy;
102 enableRuntimeLibrary =
options.enableRuntimeLibrary;
105 void runOnOperation()
override {
108 SparsificationOptions
options(parallelization, sparseEmitStrategy,
109 enableRuntimeLibrary);
113 scf::ForOp::getCanonicalizationPatterns(
patterns, ctx);
118struct StageSparseOperationsPass
119 :
public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
120 StageSparseOperationsPass() =
default;
121 StageSparseOperationsPass(
const StageSparseOperationsPass &pass) =
default;
122 void runOnOperation()
override {
130struct LowerSparseOpsToForeachPass
131 :
public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
132 LowerSparseOpsToForeachPass() =
default;
133 LowerSparseOpsToForeachPass(
const LowerSparseOpsToForeachPass &pass) =
135 LowerSparseOpsToForeachPass(
bool enableRT,
bool convert) {
136 enableRuntimeLibrary = enableRT;
137 enableConvert = convert;
140 void runOnOperation()
override {
149struct LowerForeachToSCFPass
150 :
public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
151 LowerForeachToSCFPass() =
default;
152 LowerForeachToSCFPass(
const LowerForeachToSCFPass &pass) =
default;
154 void runOnOperation()
override {
162struct LowerSparseIterationToSCFPass
163 :
public impl::LowerSparseIterationToSCFBase<
164 LowerSparseIterationToSCFPass> {
165 LowerSparseIterationToSCFPass() =
default;
166 LowerSparseIterationToSCFPass(
const LowerSparseIterationToSCFPass &) =
169 void runOnOperation()
override {
172 SparseIterationTypeConverter converter;
173 ConversionTarget
target(*ctx);
176 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
177 memref::MemRefDialect, scf::SCFDialect,
178 sparse_tensor::SparseTensorDialect>();
179 target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
181 target.addLegalOp<UnrealizedConversionCastOp>();
184 if (
failed(applyPartialConversion(getOperation(),
target,
190struct SparseTensorConversionPass
191 :
public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
192 SparseTensorConversionPass() =
default;
193 SparseTensorConversionPass(
const SparseTensorConversionPass &pass) =
default;
195 void runOnOperation()
override {
198 SparseTensorTypeToPtrConverter converter;
199 ConversionTarget
target(*ctx);
201 target.addIllegalDialect<SparseTensorDialect>();
205 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
206 return converter.isSignatureLegal(op.getFunctionType());
208 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
209 return converter.isSignatureLegal(op.getCalleeType());
211 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
212 return converter.isLegal(op.getOperandTypes());
214 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
215 return converter.isLegal(op.getOperandTypes());
217 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
218 return converter.isLegal(op.getSource().getType()) &&
219 converter.isLegal(op.getDest().getType());
221 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
222 [&](tensor::ExpandShapeOp op) {
223 return converter.isLegal(op.getSrc().getType()) &&
224 converter.isLegal(op.getResult().getType());
226 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
227 [&](tensor::CollapseShapeOp op) {
228 return converter.isLegal(op.getSrc().getType()) &&
229 converter.isLegal(op.getResult().getType());
231 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
232 [&](bufferization::AllocTensorOp op) {
233 return converter.isLegal(op.getType());
235 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
236 [&](bufferization::DeallocTensorOp op) {
237 return converter.isLegal(op.getTensor().getType());
241 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
242 linalg::YieldOp, tensor::ExtractOp,
243 tensor::FromElementsOp>();
245 arith::ArithDialect, bufferization::BufferizationDialect,
246 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
249 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
255 if (
failed(applyPartialConversion(getOperation(),
target,
261struct SparseTensorCodegenPass
262 :
public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
263 SparseTensorCodegenPass() =
default;
264 SparseTensorCodegenPass(
const SparseTensorCodegenPass &pass) =
default;
265 SparseTensorCodegenPass(
bool createDeallocs,
bool enableInit) {
266 createSparseDeallocs = createDeallocs;
267 enableBufferInitialization = enableInit;
270 void runOnOperation()
override {
273 SparseTensorTypeToBufferConverter converter;
274 ConversionTarget
target(*ctx);
276 target.addIllegalDialect<SparseTensorDialect>();
277 target.addLegalOp<SortOp>();
278 target.addLegalOp<PushBackOp>();
280 target.addLegalOp<GetStorageSpecifierOp>();
281 target.addLegalOp<SetStorageSpecifierOp>();
282 target.addLegalOp<StorageSpecifierInitOp>();
284 target.addLegalOp<tensor::FromElementsOp>();
289 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
290 return converter.isSignatureLegal(op.getFunctionType());
292 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
293 return converter.isSignatureLegal(op.getCalleeType());
295 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
296 return converter.isLegal(op.getOperandTypes());
298 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
299 [&](bufferization::AllocTensorOp op) {
300 return converter.isLegal(op.getType());
302 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
303 [&](bufferization::DeallocTensorOp op) {
304 return converter.isLegal(op.getTensor().getType());
308 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
310 arith::ArithDialect, bufferization::BufferizationDialect,
311 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
312 target.addLegalOp<UnrealizedConversionCastOp>();
314 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
319 converter,
patterns, createSparseDeallocs, enableBufferInitialization);
320 if (
failed(applyPartialConversion(getOperation(),
target,
326struct SparseBufferRewritePass
327 :
public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
328 SparseBufferRewritePass() =
default;
329 SparseBufferRewritePass(
const SparseBufferRewritePass &pass) =
default;
330 SparseBufferRewritePass(
bool enableInit) {
331 enableBufferInitialization = enableInit;
334 void runOnOperation()
override {
342struct SparseVectorizationPass
343 :
public impl::SparseVectorizationBase<SparseVectorizationPass> {
344 SparseVectorizationPass() =
default;
345 SparseVectorizationPass(
const SparseVectorizationPass &pass) =
default;
346 SparseVectorizationPass(
unsigned vl,
bool vla,
bool sidx32) {
348 enableVLAVectorization = vla;
349 enableSIMDIndex32 = sidx32;
352 void runOnOperation()
override {
353 if (vectorLength == 0)
354 return signalPassFailure();
358 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
364struct SparseGPUCodegenPass
365 :
public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
366 SparseGPUCodegenPass() =
default;
367 SparseGPUCodegenPass(
const SparseGPUCodegenPass &pass) =
default;
368 SparseGPUCodegenPass(
unsigned nT,
bool enableRT) {
370 enableRuntimeLibrary = enableRT;
373 void runOnOperation()
override {
384struct StorageSpecifierToLLVMPass
385 :
public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
386 StorageSpecifierToLLVMPass() =
default;
388 void runOnOperation()
override {
390 ConversionTarget
target(*ctx);
392 StorageSpecifierToLLVMTypeConverter converter;
395 target.addIllegalDialect<SparseTensorDialect>();
396 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
397 return converter.isSignatureLegal(op.getFunctionType());
399 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
400 return converter.isSignatureLegal(op.getCalleeType());
402 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
403 return converter.isLegal(op.getOperandTypes());
405 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
407 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
415 if (
failed(applyPartialConversion(getOperation(),
target,
428 return std::make_unique<SparseAssembler>();
432 return std::make_unique<SparseReinterpretMap>();
437 SparseReinterpretMapOptions
options;
439 return std::make_unique<SparseReinterpretMap>(
options);
444 SparseReinterpretMapOptions
options;
446 options.loopOrderingStrategy = strategy;
447 return std::make_unique<SparseReinterpretMap>(
options);
451 return std::make_unique<PreSparsificationRewritePass>();
455 return std::make_unique<SparsificationPass>();
460 return std::make_unique<SparsificationPass>(
options);
464 return std::make_unique<StageSparseOperationsPass>();
468 return std::make_unique<LowerSparseOpsToForeachPass>();
473 return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
477 return std::make_unique<LowerForeachToSCFPass>();
481 return std::make_unique<LowerSparseIterationToSCFPass>();
485 return std::make_unique<SparseTensorConversionPass>();
489 return std::make_unique<SparseTensorCodegenPass>();
494 bool enableBufferInitialization) {
495 return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
496 enableBufferInitialization);
500 return std::make_unique<SparseBufferRewritePass>();
505 return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
509 return std::make_unique<SparseVectorizationPass>();
514 bool enableVLAVectorization,
515 bool enableSIMDIndex32) {
516 return std::make_unique<SparseVectorizationPass>(
517 vectorLength, enableVLAVectorization, enableSIMDIndex32);
521 return std::make_unique<SparseGPUCodegenPass>();
526 return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
530 return std::make_unique<StorageSpecifierToLLVMPass>();
static llvm::ManagedStatic< PassManagerOptions > options
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
LoopOrderingStrategy
Defines a strategy for loop ordering during sparse code generation.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseVectorizationPass()
std::unique_ptr< Pass > createSparseAssembler()
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createLowerSparseOpsToForeachPass()
std::unique_ptr< Pass > createSparseTensorCodegenPass()
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::unique_ptr< Pass > createSparseGPUCodegenPass()
std::unique_ptr< Pass > createSparseReinterpretMapPass()
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
std::unique_ptr< Pass > createSparseTensorConversionPass()
std::unique_ptr< Pass > createSparseBufferRewritePass()
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been legalized by the conversion framework.
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
ReinterpretMapScope
Defines a scope for reinterpret map pass.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
const FrozenRewritePatternSet & patterns
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
std::unique_ptr< Pass > createStorageSpecifierToLLVMPass()
std::unique_ptr< Pass > createPreSparsificationRewritePass()
std::unique_ptr< Pass > createLowerForeachToSCFPass()
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
std::unique_ptr< Pass > createLowerSparseIterationToSCFPass()
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
std::unique_ptr< Pass > createStageSparseOperationsPass()
std::unique_ptr< Pass > createSparsificationPass()
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy=sparse_tensor::LoopOrderingStrategy::kDefault)
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
Options for the Sparsification pass.