MLIR  16.0.0git
SparseTensorPasses.cpp
Go to the documentation of this file.
1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_SPARSIFICATIONPASS
25 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
26 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
27 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
28 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 namespace {
35 
36 //===----------------------------------------------------------------------===//
37 // Passes implementation.
38 //===----------------------------------------------------------------------===//
39 
40 struct SparsificationPass
41  : public impl::SparsificationPassBase<SparsificationPass> {
42 
43  SparsificationPass() = default;
44  SparsificationPass(const SparsificationPass &pass) = default;
45  SparsificationPass(const SparsificationOptions &options) {
46  parallelization = options.parallelizationStrategy;
47  vectorization = options.vectorizationStrategy;
48  vectorLength = options.vectorLength;
49  enableSIMDIndex32 = options.enableSIMDIndex32;
50  enableVLAVectorization = options.enableVLAVectorization;
51  enableRuntimeLibrary = options.enableRuntimeLibrary;
52  }
53 
54  void runOnOperation() override {
55  auto *ctx = &getContext();
56  RewritePatternSet prePatterns(ctx);
57  // Translate strategy flags to strategy options.
58  SparsificationOptions options(parallelization, vectorization, vectorLength,
59  enableSIMDIndex32, enableVLAVectorization,
60  enableRuntimeLibrary);
61  // Apply pre-rewriting.
63  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
64  // Apply sparsification and vector cleanup rewriting.
65  RewritePatternSet patterns(ctx);
66  populateSparsificationPatterns(patterns, options);
68  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
69  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
70  }
71 };
72 
73 struct SparseTensorConversionPass
74  : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
75 
76  SparseTensorConversionPass() = default;
77  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
78  SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
79  sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
80  }
81 
82  void runOnOperation() override {
83  auto *ctx = &getContext();
84  RewritePatternSet patterns(ctx);
86  ConversionTarget target(*ctx);
87  // Everything in the sparse dialect must go!
88  target.addIllegalDialect<SparseTensorDialect>();
89  // All dynamic rules below accept new function, call, return, and various
90  // tensor and bufferization operations as legal output of the rewriting
91  // provided that all sparse tensor types have been fully rewritten.
92  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
93  return converter.isSignatureLegal(op.getFunctionType());
94  });
95  target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
96  return converter.isSignatureLegal(op.getCalleeType());
97  });
98  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
99  return converter.isLegal(op.getOperandTypes());
100  });
101  target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
102  return converter.isLegal(op.getOperandTypes());
103  });
104  target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
105  return converter.isLegal(op.getSource().getType()) &&
106  converter.isLegal(op.getDest().getType());
107  });
108  target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
109  [&](tensor::ExpandShapeOp op) {
110  return converter.isLegal(op.getSrc().getType()) &&
111  converter.isLegal(op.getResult().getType());
112  });
113  target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
114  [&](tensor::CollapseShapeOp op) {
115  return converter.isLegal(op.getSrc().getType()) &&
116  converter.isLegal(op.getResult().getType());
117  });
118  target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
119  [&](bufferization::AllocTensorOp op) {
120  return converter.isLegal(op.getType());
121  });
122  target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
123  [&](bufferization::DeallocTensorOp op) {
124  return converter.isLegal(op.getTensor().getType());
125  });
126  // The following operations and dialects may be introduced by the
127  // rewriting rules, and are therefore marked as legal.
128  target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
129  linalg::YieldOp, tensor::ExtractOp>();
130  target.addLegalDialect<
131  arith::ArithDialect, bufferization::BufferizationDialect,
132  LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
133  // Translate strategy flags to strategy options.
135  sparseToSparseConversionStrategy(sparseToSparse));
136  // Populate with rules and apply rewriting rules.
137  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
138  converter);
139  populateCallOpTypeConversionPattern(patterns, converter);
141  target);
142  populateSparseTensorConversionPatterns(converter, patterns, options);
143  if (failed(applyPartialConversion(getOperation(), target,
144  std::move(patterns))))
145  signalPassFailure();
146  }
147 };
148 
149 struct SparseTensorCodegenPass
150  : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
151 
152  SparseTensorCodegenPass() = default;
153  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
154 
155  void runOnOperation() override {
156  auto *ctx = &getContext();
157  RewritePatternSet patterns(ctx);
159  ConversionTarget target(*ctx);
160  // Most ops in the sparse dialect must go!
161  target.addIllegalDialect<SparseTensorDialect>();
162  target.addLegalOp<SortOp>();
163  target.addLegalOp<PushBackOp>();
164  // All dynamic rules below accept new function, call, return, and various
165  // tensor and bufferization operations as legal output of the rewriting
166  // provided that all sparse tensor types have been fully rewritten.
167  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
168  return converter.isSignatureLegal(op.getFunctionType());
169  });
170  target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
171  return converter.isSignatureLegal(op.getCalleeType());
172  });
173  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
174  return converter.isLegal(op.getOperandTypes());
175  });
176  target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
177  [&](bufferization::AllocTensorOp op) {
178  return converter.isLegal(op.getType());
179  });
180  target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
181  [&](bufferization::DeallocTensorOp op) {
182  return converter.isLegal(op.getTensor().getType());
183  });
184  // The following operations and dialects may be introduced by the
185  // codegen rules, and are therefore marked as legal.
186  target.addLegalOp<linalg::FillOp>();
187  target.addLegalDialect<arith::ArithDialect,
188  bufferization::BufferizationDialect,
189  memref::MemRefDialect, scf::SCFDialect>();
190  target.addLegalOp<UnrealizedConversionCastOp>();
191  // Populate with rules and apply rewriting rules.
192  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
193  converter);
195  target);
196  populateSparseTensorCodegenPatterns(converter, patterns);
197  if (failed(applyPartialConversion(getOperation(), target,
198  std::move(patterns))))
199  signalPassFailure();
200  }
201 };
202 
203 struct SparseBufferRewritePass
204  : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
205 
206  SparseBufferRewritePass() = default;
207  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
208 
209  void runOnOperation() override {
210  auto *ctx = &getContext();
211  RewritePatternSet patterns(ctx);
213  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
214  }
215 };
216 
217 } // namespace
218 
219 //===----------------------------------------------------------------------===//
220 // Strategy flag methods.
221 //===----------------------------------------------------------------------===//
222 
225  switch (flag) {
226  default:
228  case 1:
230  case 2:
232  }
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // Pass creation methods.
237 //===----------------------------------------------------------------------===//
238 
239 std::unique_ptr<Pass> mlir::createSparsificationPass() {
240  return std::make_unique<SparsificationPass>();
241 }
242 
243 std::unique_ptr<Pass>
245  return std::make_unique<SparsificationPass>(options);
246 }
247 
248 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
249  return std::make_unique<SparseTensorConversionPass>();
250 }
251 
254  return std::make_unique<SparseTensorConversionPass>(options);
255 }
256 
257 std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
258  return std::make_unique<SparseTensorCodegenPass>();
259 }
260 
261 std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
262  return std::make_unique<SparseBufferRewritePass>();
263 }
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseTensorCodegenPass()
void addLegalOp(OperationName op)
Register the given operations as legal.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
Sparse tensor type converter into an actual buffer.
Definition: Passes.h:149
std::unique_ptr< Pass > createSparseBufferRewritePass()
SparseToSparseConversionStrategy sparseToSparseStrategy
Definition: Passes.h:131
void populateSCFStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT)
SparseToSparseConversionStrategy
Defines a strategy for implementing sparse-to-sparse conversion.
Definition: Passes.h:119
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor codegen rules.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
SparseToSparseConversionStrategy sparseToSparseConversionStrategy(int32_t flag)
Converts command-line sparse2sparse flag to the strategy enum.
SparseTensorConversion options.
Definition: Passes.h:125
void populateSparseTensorConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, const SparseTensorConversionOptions &options=SparseTensorConversionOptions())
Sets up sparse tensor conversion rules.
std::unique_ptr< Pass > createSparseTensorConversionPass()
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
Sparse tensor type converter into an opaque pointer.
Definition: Passes.h:96
SparseVectorizationStrategy vectorizationStrategy
Definition: Passes.h:75
static llvm::ManagedStatic< PassManagerOptions > options
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations of this dialect are not supported by the target.
bool isSignatureLegal(FunctionType ty)
Return true if the inputs and outputs of the given function type are legal.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
SparseParallelizationStrategy parallelizationStrategy
Definition: Passes.h:74
This class describes a specific conversion target.
Options for the Sparsification pass.
Definition: Passes.h:61
void populateSparseBufferRewriting(RewritePatternSet &patterns)
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
std::unique_ptr< Pass > createSparsificationPass()