//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
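
// The GEN_PASS_DEF_* macros above select which TableGen-generated pass
// definitions Passes.h.inc expands: one impl::<PassName>Base class per pass,
// each carrying its declared options (e.g. directOut, scope) as protected
// members that the pass constructors below simply assign.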

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

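// Like most passes in this file, SparseAssembler is a plain pattern pass:
// collect the rewrite rules, then let the greedy driver apply them until
// fixpoint. A hypothetical command-line invocation (the exact option
// spelling is defined by the pass's TableGen entry, not here):
//
//   mlir-opt --sparse-assembler="direct-out=true" input.mlir
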
struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
    loopOrderingStrategy = options.loopOrderingStrategy;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

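// Sketch (illustrative, not part of this file): driving sparsification with
// explicit options from C++; assumes `module` is a ModuleOp with all
// relevant dialects registered.
//
//   SparsificationOptions options(SparseParallelizationStrategy::kNone,
//                                 SparseEmitStrategy::kFunctional,
//                                 /*enableRuntimeLibrary=*/true);
//   PassManager pm(module.getContext());
//   pm.addPass(createSparsificationPass(options));
//   if (failed(pm.run(module)))
//     llvm::report_fatal_error("sparsification failed");
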
struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseIterationTypeConverter converter;
    ConversionTarget target(*ctx);

    // The actual conversion.
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, scf::SCFDialect,
                           sparse_tensor::SparseTensorDialect>();
    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
                        IterateOp>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    populateLowerSparseIterationToSCFPatterns(converter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

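// Unlike the greedy pattern passes above, this pass and the conversion
// passes below use the dialect conversion framework: ops marked illegal on
// the ConversionTarget must be rewritten away, and a failed partial
// conversion is surfaced through signalPassFailure().
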
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

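// The dynamic legality callbacks above (and in the codegen pass below)
// share one idea: func/call/return and friends only become legal once the
// type converter finds no sparse tensor types left in their signatures, so
// signature rewriting and sparse op lowering converge in a single partial
// conversion.
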
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    // A vector length of zero is meaningless, so treat it as a pass failure.
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

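// A zero thread count selects the library-generation path, which rewrites
// suitable sparse kernels into calls to a GPU support library; a nonzero
// count instead drives direct GPU kernel code generation using that many
// threads.
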
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
    ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  options.loopOrderingStrategy = strategy;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
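
// Sketch of a minimal end-to-end lowering pipeline assembled from the
// factories above; the ordering follows the usual sparse compiler flow but
// is illustrative only. The production pipeline (including bufferization
// and cleanup passes) lives in SparseTensorPipelines.cpp.
//
//   PassManager pm(ctx);
//   pm.addPass(createPreSparsificationRewritePass());
//   pm.addPass(createSparsificationPass());
//   pm.addPass(createLowerSparseOpsToForeachPass());
//   pm.addPass(createLowerForeachToSCFPass());
//   pm.addPass(createSparseTensorConversionPass());
//   pm.addPass(createStorageSpecifierToLLVMPass());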