//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

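// Usage sketch (illustrative, not part of the upstream file). The pass is
// normally driven through its registered command-line flag; the flag and
// option names below are assumptions taken from the tablegen definitions
// in Passes.td:
//
//   mlir-opt input.mlir --sparse-assembler="direct-out=true"
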
struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
    loopOrderingStrategy = options.loopOrderingStrategy;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

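// Usage sketch (illustrative): the strategy flags above surface as pass
// options on the command line. The option string below is an assumption
// based on the registered names in Passes.td:
//
//   mlir-opt input.mlir \
//     --sparsification="parallelization-strategy=any-storage-any-loop"
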
struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseIterationTypeConverter converter;
    ConversionTarget target(*ctx);

    // The actual conversion.
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, scf::SCFDialect,
                           sparse_tensor::SparseTensorDialect>();
    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
                        IterateOp>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    populateLowerSparseIterationToSCFPatterns(converter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

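// For intuition, a sketch (hypothetical IR, not from the source) of what the
// dynamic legality rules above enforce: a function remains illegal until all
// of its sparse tensor types have been rewritten to the opaque runtime
// pointer type produced by SparseTensorTypeToPtrConverter.
//
//   func.func @sum(%arg0: tensor<?xf64, #SparseVector>) -> f64  // illegal
//   func.func @sum(%arg0: !llvm.ptr) -> f64                     // legal
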
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
    ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  options.loopOrderingStrategy = strategy;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

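// Usage sketch (illustrative): constructing the pass programmatically with
// explicit options instead of pass flags. The enum values are assumptions
// drawn from the SparsificationOptions declaration in Passes.h.
//
//   SparsificationOptions options(SparseParallelizationStrategy::kNone,
//                                 SparseEmitStrategy::kFunctional,
//                                 /*enableRuntimeLibrary=*/true);
//   pm.addPass(createSparsificationPass(options));
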
std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
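
//===----------------------------------------------------------------------===//
// Example: composing the factories above (illustrative only).
//===----------------------------------------------------------------------===//

// A minimal sketch of one plausible ordering of these passes; the helper name
// is hypothetical, and the officially supported ordering is assembled in the
// sparsifier pipeline (SparsifierPipelines.cpp), not here.
//
//   void buildSparsifierSketch(OpPassManager &pm) {
//     pm.addPass(createPreSparsificationRewritePass());
//     pm.addPass(createSparseReinterpretMapPass());
//     pm.addPass(createSparsificationPass());
//     pm.addPass(createLowerSparseOpsToForeachPass());
//     pm.addPass(createLowerForeachToSCFPass());
//     pm.addPass(createSparseTensorConversionPass());
//     pm.addPass(createStorageSpecifierToLLVMPass());
//   }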