//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEENCODINGPROPAGATION
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

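// Adds assemble/disassemble conversions around the sparse tensor arguments
// and results of public functions (see populateSparseAssembler); the
// `directOut` option is forwarded to the patterns to control direct output
// of the underlying buffers.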
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

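// Propagates sparse tensor encodings through the IR; the implementation is
// currently a stub (runOnOperation is a no-op).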
struct SparseEncodingPropagation
    : public impl::SparseEncodingPropagationBase<SparseEncodingPropagation> {
  SparseEncodingPropagation() = default;
  SparseEncodingPropagation(const SparseEncodingPropagation &pass) = default;

  void runOnOperation() override {}
};

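// Rewrites sparse tensor types with non-trivial dimension-to-level mappings
// into forms the rest of the pipeline can handle; `scope` restricts which
// operations are rewritten (see populateSparseReinterpretMap).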
struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

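// Applies rewriting rules that prepare the IR for sparsification proper
// (see populatePreSparsificationRewriting).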
struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

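// The main sparsification pass: translates the pass flags into
// SparsificationOptions, applies the sparsification rewriting rules, and
// cleans up with scf.for canonicalization.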
struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

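// Stages complex sparse operations into simpler intermediate steps
// (see populateStageSparseOperationsPatterns).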
struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

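// Lowers high-level sparse tensor operations to sparse_tensor.foreach loops;
// `enableRT` and `enableConvert` are forwarded to the pattern population.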
struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

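// Lowers sparse_tensor.foreach operations into explicit scf control flow
// (see populateLowerForeachToSCFPatterns).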
struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

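// Converts sparse tensor types and primitives into opaque pointers and calls
// into a runtime support library, using a partial dialect conversion.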
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

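// Converts sparse tensor types into actual buffers (compiler-generated
// storage), eliminating most sparse tensor ops through a partial conversion.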
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

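// Rewrites sparse buffer operations such as sort and push_back into actual
// code (see populateSparseBufferRewriting); `enableBufferInitialization` is
// forwarded to the patterns to control initialization of allocated buffers.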
struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

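// Vectorizes loops produced by the sparsifier; a zero vector length is
// rejected up front, and vector-to-vector canonicalization runs as cleanup.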
struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

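// GPU acceleration for sparse kernels: numThreads == 0 selects the library
// path (populateSparseGPULibgenPatterns), otherwise direct GPU codegen with
// the requested number of threads.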
struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

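// Converts the remaining sparse_tensor.storage_specifier ops and types into
// LLVM dialect constructs through a partial conversion.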
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseEncodingPropagationPass() {
  return std::make_unique<SparseEncodingPropagation>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}