//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
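
// Note: the GEN_PASS_DEF_* macros above instruct the tablegen-generated
// Passes.h.inc to emit the impl::*Base CRTP base classes (with the pass
// options as member fields) that the pass implementations below derive from.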

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {

  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
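
// A minimal usage sketch (assuming the textual pass name registered for this
// pass in Passes.td is "pre-sparsification-rewrite"):
//
//   mlir-opt input.mlir --pre-sparsification-rewrite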

struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {

  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    gpuDataTransfer = options.gpuDataTransferStrategy;
    enableIndexReduction = options.enableIndexReduction;
    enableGPULibgen = options.enableGPULibgen;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, gpuDataTransfer,
                                  enableIndexReduction, enableGPULibgen,
                                  enableRuntimeLibrary);
    // Apply GPU libgen (if requested), sparsification, and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    if (enableGPULibgen) {
      // TODO: Zero copy is disabled due to correctness bugs. Tracker #64316.
      assert(gpuDataTransfer != GPUDataTransferStrategy::kZeroCopy &&
             "zero-copy transfer not supported with GPU libgen");
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary,
                                      gpuDataTransfer);
    }
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
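
// A usage sketch in C++ (hypothetical driver code; assumes a loaded ModuleOp
// `module` with all relevant dialects registered):
//
//   SparsificationOptions options(SparseParallelizationStrategy::kNone,
//                                 GPUDataTransferStrategy::kRegularDMA,
//                                 /*enableIndexReduction=*/false,
//                                 /*enableGPULibgen=*/false,
//                                 /*enableRuntimeLibrary=*/true);
//   PassManager pm(module->getContext());
//   pm.addPass(createSparsificationPass(options));
//   if (failed(pm.run(module)))
//     llvm::errs() << "sparsification failed\n";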

struct PostSparsificationRewritePass
    : public impl::PostSparsificationRewriteBase<
          PostSparsificationRewritePass> {

  PostSparsificationRewritePass() = default;
  PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
      default;
  PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableForeach = foreach;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
                                        enableForeach, enableConvert);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
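
// Note: the enableForeach/enableConvert flags control whether the rewriting
// expands sparse_tensor.foreach and sparse_tensor.convert operations here;
// see populatePostSparsificationRewriting for the authoritative semantics.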

struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {

  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
  SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
    sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
    // Translate strategy flags to strategy options.
    SparseTensorConversionOptions options(
        sparseToSparseConversionStrategy(sparseToSparse));
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns, options);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
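
// Note that applyPartialConversion only requires the explicitly illegal and
// dynamically legal operations above to be legalized; any other operation
// left in the IR is allowed to remain as-is.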

struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {

  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortCooOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
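
// In contrast to SparseTensorConversionPass, which converts each sparse
// tensor type into an opaque pointer into the runtime support library, this
// pass converts sparse tensor types into actual buffers (see the two type
// converters declared in Passes.h), so the storage itself needs no runtime
// library support.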

struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {

  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {

  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
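
// A usage sketch in C++ (hypothetical driver code): a vector length of zero
// is rejected above, so a caller must pick an explicit width, e.g.
//
//   pm.addPass(createSparseVectorizationPass(
//       /*vectorLength=*/16, /*enableVLAVectorization=*/false,
//       /*enableSIMDIndex32=*/false));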

struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {

  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT) { numThreads = nT; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {

  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
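
// This pass runs late in the sparse pipeline: the storage-specifier values
// that were deliberately kept legal during codegen above are lowered here
// into LLVM dialect values, after which nothing of the sparse dialect
// remains in the IR.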

} // namespace

//===----------------------------------------------------------------------===//
// Strategy flag methods.
//===----------------------------------------------------------------------===//

SparseToSparseConversionStrategy
mlir::sparseToSparseConversionStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseToSparseConversionStrategy::kAuto;
  case 1:
    return SparseToSparseConversionStrategy::kViaCOO;
  case 2:
    return SparseToSparseConversionStrategy::kDirect;
  }
}
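
// For example, sparseToSparseConversionStrategy(2) yields kDirect, while any
// unrecognized flag value (e.g. 0) falls back to kAuto.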

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
  return std::make_unique<PostSparsificationRewritePass>();
}

std::unique_ptr<Pass>
mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
                                          bool enableConvert) {
  return std::make_unique<PostSparsificationRewritePass>(
      enableRT, enableForeach, enableConvert);
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
    const SparseTensorConversionOptions &options) {
  return std::make_unique<SparseTensorConversionPass>(options);
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
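
// A minimal end-to-end sketch (hypothetical driver code; assumes a loaded
// ModuleOp `module` with all relevant dialects registered) chaining the
// passes defined in this file in their usual order:
//
//   PassManager pm(module->getContext());
//   pm.addPass(createPreSparsificationRewritePass());
//   pm.addPass(createSparsificationPass());
//   pm.addPass(createPostSparsificationRewritePass());
//   pm.addPass(createSparseTensorConversionPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "sparse compilation pipeline failed\n";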