25 #define GEN_PASS_DEF_SPARSEASSEMBLER
26 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
28 #define GEN_PASS_DEF_SPARSIFICATIONPASS
29 #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
30 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
31 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
32 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
33 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
34 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
35 #define GEN_PASS_DEF_SPARSEVECTORIZATION
36 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
37 #define GEN_PASS_DEF_STAGESPARSEOPERATIONS
38 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
39 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
51 struct SparseAssembler :
public impl::SparseAssemblerBase<SparseAssembler> {
52 SparseAssembler() =
default;
53 SparseAssembler(
const SparseAssembler &pass) =
default;
54 SparseAssembler(
bool dO) { directOut = dO; }
56 void runOnOperation()
override {
64 struct SparseReinterpretMap
65 :
public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
66 SparseReinterpretMap() =
default;
67 SparseReinterpretMap(
const SparseReinterpretMap &pass) =
default;
68 SparseReinterpretMap(
const SparseReinterpretMapOptions &
options) {
72 void runOnOperation()
override {
80 struct PreSparsificationRewritePass
81 :
public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
82 PreSparsificationRewritePass() =
default;
83 PreSparsificationRewritePass(
const PreSparsificationRewritePass &pass) =
86 void runOnOperation()
override {
94 struct SparsificationPass
95 :
public impl::SparsificationPassBase<SparsificationPass> {
96 SparsificationPass() =
default;
97 SparsificationPass(
const SparsificationPass &pass) =
default;
99 parallelization =
options.parallelizationStrategy;
100 sparseEmitStrategy =
options.sparseEmitStrategy;
101 enableRuntimeLibrary =
options.enableRuntimeLibrary;
104 void runOnOperation()
override {
108 enableRuntimeLibrary);
112 scf::ForOp::getCanonicalizationPatterns(
patterns, ctx);
117 struct StageSparseOperationsPass
118 :
public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
119 StageSparseOperationsPass() =
default;
120 StageSparseOperationsPass(
const StageSparseOperationsPass &pass) =
default;
121 void runOnOperation()
override {
129 struct LowerSparseOpsToForeachPass
130 :
public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
131 LowerSparseOpsToForeachPass() =
default;
132 LowerSparseOpsToForeachPass(
const LowerSparseOpsToForeachPass &pass) =
134 LowerSparseOpsToForeachPass(
bool enableRT,
bool convert) {
135 enableRuntimeLibrary = enableRT;
136 enableConvert = convert;
139 void runOnOperation()
override {
148 struct LowerForeachToSCFPass
149 :
public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
150 LowerForeachToSCFPass() =
default;
151 LowerForeachToSCFPass(
const LowerForeachToSCFPass &pass) =
default;
153 void runOnOperation()
override {
161 struct LowerSparseIterationToSCFPass
162 :
public impl::LowerSparseIterationToSCFBase<
163 LowerSparseIterationToSCFPass> {
164 LowerSparseIterationToSCFPass() =
default;
165 LowerSparseIterationToSCFPass(
const LowerSparseIterationToSCFPass &) =
168 void runOnOperation()
override {
175 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
176 memref::MemRefDialect, scf::SCFDialect,
177 sparse_tensor::SparseTensorDialect>();
178 target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
180 target.addLegalOp<UnrealizedConversionCastOp>();
189 struct SparseTensorConversionPass
190 :
public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
191 SparseTensorConversionPass() =
default;
192 SparseTensorConversionPass(
const SparseTensorConversionPass &pass) =
default;
194 void runOnOperation()
override {
200 target.addIllegalDialect<SparseTensorDialect>();
204 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
207 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
210 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
213 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
216 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
220 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
221 [&](tensor::ExpandShapeOp op) {
225 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
226 [&](tensor::CollapseShapeOp op) {
230 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
231 [&](bufferization::AllocTensorOp op) {
234 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
235 [&](bufferization::DeallocTensorOp op) {
240 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
241 linalg::YieldOp, tensor::ExtractOp,
242 tensor::FromElementsOp>();
243 target.addLegalDialect<
244 arith::ArithDialect, bufferization::BufferizationDialect,
245 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
248 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
260 struct SparseTensorCodegenPass
261 :
public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
262 SparseTensorCodegenPass() =
default;
263 SparseTensorCodegenPass(
const SparseTensorCodegenPass &pass) =
default;
264 SparseTensorCodegenPass(
bool createDeallocs,
bool enableInit) {
265 createSparseDeallocs = createDeallocs;
266 enableBufferInitialization = enableInit;
269 void runOnOperation()
override {
275 target.addIllegalDialect<SparseTensorDialect>();
276 target.addLegalOp<SortOp>();
277 target.addLegalOp<PushBackOp>();
279 target.addLegalOp<GetStorageSpecifierOp>();
280 target.addLegalOp<SetStorageSpecifierOp>();
281 target.addLegalOp<StorageSpecifierInitOp>();
283 target.addLegalOp<tensor::FromElementsOp>();
288 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
291 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
294 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
297 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
298 [&](bufferization::AllocTensorOp op) {
301 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
302 [&](bufferization::DeallocTensorOp op) {
307 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
308 target.addLegalDialect<
309 arith::ArithDialect, bufferization::BufferizationDialect,
310 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
311 target.addLegalOp<UnrealizedConversionCastOp>();
313 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
325 struct SparseBufferRewritePass
326 :
public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
327 SparseBufferRewritePass() =
default;
328 SparseBufferRewritePass(
const SparseBufferRewritePass &pass) =
default;
329 SparseBufferRewritePass(
bool enableInit) {
330 enableBufferInitialization = enableInit;
333 void runOnOperation()
override {
341 struct SparseVectorizationPass
342 :
public impl::SparseVectorizationBase<SparseVectorizationPass> {
343 SparseVectorizationPass() =
default;
344 SparseVectorizationPass(
const SparseVectorizationPass &pass) =
default;
345 SparseVectorizationPass(
unsigned vl,
bool vla,
bool sidx32) {
347 enableVLAVectorization = vla;
348 enableSIMDIndex32 = sidx32;
351 void runOnOperation()
override {
352 if (vectorLength == 0)
353 return signalPassFailure();
357 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
363 struct SparseGPUCodegenPass
364 :
public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
365 SparseGPUCodegenPass() =
default;
366 SparseGPUCodegenPass(
const SparseGPUCodegenPass &pass) =
default;
367 SparseGPUCodegenPass(
unsigned nT,
bool enableRT) {
369 enableRuntimeLibrary = enableRT;
372 void runOnOperation()
override {
383 struct StorageSpecifierToLLVMPass
384 :
public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
385 StorageSpecifierToLLVMPass() =
default;
387 void runOnOperation()
override {
394 target.addIllegalDialect<SparseTensorDialect>();
395 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
398 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
401 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
404 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
406 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
427 return std::make_unique<SparseAssembler>();
431 return std::make_unique<SparseReinterpretMap>();
434 std::unique_ptr<Pass>
436 SparseReinterpretMapOptions
options;
438 return std::make_unique<SparseReinterpretMap>(
options);
442 return std::make_unique<PreSparsificationRewritePass>();
446 return std::make_unique<SparsificationPass>();
449 std::unique_ptr<Pass>
451 return std::make_unique<SparsificationPass>(
options);
455 return std::make_unique<StageSparseOperationsPass>();
459 return std::make_unique<LowerSparseOpsToForeachPass>();
462 std::unique_ptr<Pass>
464 return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
468 return std::make_unique<LowerForeachToSCFPass>();
472 return std::make_unique<LowerSparseIterationToSCFPass>();
476 return std::make_unique<SparseTensorConversionPass>();
480 return std::make_unique<SparseTensorCodegenPass>();
483 std::unique_ptr<Pass>
485 bool enableBufferInitialization) {
486 return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
487 enableBufferInitialization);
491 return std::make_unique<SparseBufferRewritePass>();
494 std::unique_ptr<Pass>
496 return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
500 return std::make_unique<SparseVectorizationPass>();
503 std::unique_ptr<Pass>
505 bool enableVLAVectorization,
506 bool enableSIMDIndex32) {
507 return std::make_unique<SparseVectorizationPass>(
508 vectorLength, enableVLAVectorization, enableSIMDIndex32);
512 return std::make_unique<SparseGPUCodegenPass>();
517 return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
521 return std::make_unique<StorageSpecifierToLLVMPass>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
Sparse tensor type converter into an actual buffer.
Sparse tensor type converter into an opaque pointer.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseVectorizationPass()
std::unique_ptr< Pass > createSparseAssembler()
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createLowerSparseOpsToForeachPass()
std::unique_ptr< Pass > createSparseTensorCodegenPass()
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::unique_ptr< Pass > createSparseGPUCodegenPass()
std::unique_ptr< Pass > createSparseReinterpretMapPass()
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
std::unique_ptr< Pass > createSparseTensorConversionPass()
std::unique_ptr< Pass > createSparseBufferRewritePass()
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
ReinterpretMapScope
Defines a scope for reinterpret map pass.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
const FrozenRewritePatternSet & patterns
std::unique_ptr< Pass > createStorageSpecifierToLLVMPass()
std::unique_ptr< Pass > createPreSparsificationRewritePass()
std::unique_ptr< Pass > createLowerForeachToSCFPass()
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
std::unique_ptr< Pass > createLowerSparseIterationToSCFPass()
const TypeConverter & converter
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
std::unique_ptr< Pass > createStageSparseOperationsPass()
std::unique_ptr< Pass > createSparsificationPass()
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Type converter for iter_space and iterator.
Options for the Sparsification pass.