//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

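// Applies the sparse-assembler rewriting; the directOut flag is forwarded
// to the patterns to control how results are returned to the caller.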
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

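// Reinterprets the dimension-to-level map of sparse tensors within the
// configured scope (see ReinterpretMapScope in Passes.h).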
struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

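// Runs the rewriting that prepares general IR for the sparsification pass
// proper.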
struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

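// Core sparsification: translates the pass flags back into
// SparsificationOptions and applies the sparsification patterns, followed
// by scf.for canonicalization as cleanup.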
struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

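// Stages complex sparse operations by decomposing them into simpler,
// incremental steps.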
struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

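// Lowers sparse tensor operations to sparse_tensor.foreach loops; enableRT
// selects the runtime-library style and enableConvert controls whether
// conversions are lowered as well.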
struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

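// Lowers sparse_tensor.foreach operations into SCF loops.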
struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

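// Converts the sparse iteration ops (iteration space extraction and
// sparse_tensor.iterate) into SCF, using a one-to-n type conversion since
// an iterator may lower to multiple SSA values.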
struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseIterationTypeConverter converter;
    ConversionTarget target(*ctx);

    // The actual conversion.
    target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
    populateLowerSparseIterationToSCFPatterns(converter, patterns);

    if (failed(applyPartialOneToNConversion(getOperation(), converter,
                                            std::move(patterns))))
      signalPassFailure();
  }
};

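// Converts sparse tensor types and primitives into opaque pointers and
// calls into a runtime support library (see SparseTensorTypeToPtrConverter).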
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

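// Code generation path: converts sparse tensor types into actual buffers
// (see SparseTensorTypeToBufferConverter), avoiding the runtime library.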
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives the sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

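// Rewrites operations on sparse buffers (e.g. sorting) into actual code,
// optionally initializing newly allocated buffers.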
struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

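// Vectorizes loops produced by sparsification; requires a nonzero vector
// length and optionally uses VLA (scalable) vectors and 32-bit SIMD indices.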
struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

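// GPU codegen: with numThreads == 0 generates calls into a GPU-enabled
// runtime library, otherwise generates parallel GPU kernels with the given
// number of threads.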
struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

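// Lowers the remaining storage specifier operations to the LLVM dialect.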
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

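// For illustration only: a minimal sketch of how a client might chain a few
// of these entry points (assuming a PassManager `pm` rooted at the module;
// the exact pipeline is use-case dependent):
//
//   PassManager pm(ctx);
//   pm.addPass(createPreSparsificationRewritePass());
//   pm.addPass(createSparsificationPass());
//   pm.addPass(createSparseTensorConversionPass());
//   if (failed(pm.run(module)))
//     ...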
std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}