25 #define GEN_PASS_DEF_SPARSEASSEMBLER
26 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
28 #define GEN_PASS_DEF_SPARSIFICATIONPASS
29 #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
30 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
31 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
32 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
33 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
34 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
35 #define GEN_PASS_DEF_SPARSEVECTORIZATION
36 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
37 #define GEN_PASS_DEF_STAGESPARSEOPERATIONS
38 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
39 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
51 struct SparseAssembler :
public impl::SparseAssemblerBase<SparseAssembler> {
52 SparseAssembler() =
default;
53 SparseAssembler(
const SparseAssembler &pass) =
default;
54 SparseAssembler(
bool dO) { directOut = dO; }
56 void runOnOperation()
override {
64 struct SparseReinterpretMap
65 :
public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
66 SparseReinterpretMap() =
default;
67 SparseReinterpretMap(
const SparseReinterpretMap &pass) =
default;
68 SparseReinterpretMap(
const SparseReinterpretMapOptions &
options) {
72 void runOnOperation()
override {
80 struct PreSparsificationRewritePass
81 :
public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
82 PreSparsificationRewritePass() =
default;
83 PreSparsificationRewritePass(
const PreSparsificationRewritePass &pass) =
86 void runOnOperation()
override {
94 struct SparsificationPass
95 :
public impl::SparsificationPassBase<SparsificationPass> {
96 SparsificationPass() =
default;
97 SparsificationPass(
const SparsificationPass &pass) =
default;
99 parallelization =
options.parallelizationStrategy;
100 sparseEmitStrategy =
options.sparseEmitStrategy;
101 enableRuntimeLibrary =
options.enableRuntimeLibrary;
104 void runOnOperation()
override {
108 enableRuntimeLibrary);
112 scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
117 struct StageSparseOperationsPass
118 :
public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
119 StageSparseOperationsPass() =
default;
120 StageSparseOperationsPass(
const StageSparseOperationsPass &pass) =
default;
121 void runOnOperation()
override {
129 struct LowerSparseOpsToForeachPass
130 :
public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
131 LowerSparseOpsToForeachPass() =
default;
132 LowerSparseOpsToForeachPass(
const LowerSparseOpsToForeachPass &pass) =
134 LowerSparseOpsToForeachPass(
bool enableRT,
bool convert) {
135 enableRuntimeLibrary = enableRT;
136 enableConvert = convert;
139 void runOnOperation()
override {
148 struct LowerForeachToSCFPass
149 :
public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
150 LowerForeachToSCFPass() =
default;
151 LowerForeachToSCFPass(
const LowerForeachToSCFPass &pass) =
default;
153 void runOnOperation()
override {
161 struct LowerSparseIterationToSCFPass
162 :
public impl::LowerSparseIterationToSCFBase<
163 LowerSparseIterationToSCFPass> {
164 LowerSparseIterationToSCFPass() =
default;
165 LowerSparseIterationToSCFPass(
const LowerSparseIterationToSCFPass &) =
168 void runOnOperation()
override {
175 target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
179 std::move(patterns))))
184 struct SparseTensorConversionPass
185 :
public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
186 SparseTensorConversionPass() =
default;
187 SparseTensorConversionPass(
const SparseTensorConversionPass &pass) =
default;
189 void runOnOperation()
override {
195 target.addIllegalDialect<SparseTensorDialect>();
199 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
202 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
205 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
206 return converter.
isLegal(op.getOperandTypes());
208 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
209 return converter.
isLegal(op.getOperandTypes());
211 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
212 return converter.
isLegal(op.getSource().getType()) &&
213 converter.
isLegal(op.getDest().getType());
215 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
216 [&](tensor::ExpandShapeOp op) {
217 return converter.
isLegal(op.getSrc().getType()) &&
218 converter.
isLegal(op.getResult().getType());
220 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
221 [&](tensor::CollapseShapeOp op) {
222 return converter.
isLegal(op.getSrc().getType()) &&
223 converter.
isLegal(op.getResult().getType());
225 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
226 [&](bufferization::AllocTensorOp op) {
227 return converter.
isLegal(op.getType());
229 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
230 [&](bufferization::DeallocTensorOp op) {
231 return converter.
isLegal(op.getTensor().getType());
235 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
236 linalg::YieldOp, tensor::ExtractOp,
237 tensor::FromElementsOp>();
238 target.addLegalDialect<
239 arith::ArithDialect, bufferization::BufferizationDialect,
240 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
243 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
250 std::move(patterns))))
255 struct SparseTensorCodegenPass
256 :
public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
257 SparseTensorCodegenPass() =
default;
258 SparseTensorCodegenPass(
const SparseTensorCodegenPass &pass) =
default;
259 SparseTensorCodegenPass(
bool createDeallocs,
bool enableInit) {
260 createSparseDeallocs = createDeallocs;
261 enableBufferInitialization = enableInit;
264 void runOnOperation()
override {
270 target.addIllegalDialect<SparseTensorDialect>();
271 target.addLegalOp<SortOp>();
272 target.addLegalOp<PushBackOp>();
274 target.addLegalOp<GetStorageSpecifierOp>();
275 target.addLegalOp<SetStorageSpecifierOp>();
276 target.addLegalOp<StorageSpecifierInitOp>();
278 target.addLegalOp<tensor::FromElementsOp>();
283 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
286 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
289 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
290 return converter.
isLegal(op.getOperandTypes());
292 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
293 [&](bufferization::AllocTensorOp op) {
294 return converter.
isLegal(op.getType());
296 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
297 [&](bufferization::DeallocTensorOp op) {
298 return converter.
isLegal(op.getTensor().getType());
302 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
303 target.addLegalDialect<
304 arith::ArithDialect, bufferization::BufferizationDialect,
305 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
306 target.addLegalOp<UnrealizedConversionCastOp>();
308 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
313 converter, patterns, createSparseDeallocs, enableBufferInitialization);
315 std::move(patterns))))
320 struct SparseBufferRewritePass
321 :
public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
322 SparseBufferRewritePass() =
default;
323 SparseBufferRewritePass(
const SparseBufferRewritePass &pass) =
default;
324 SparseBufferRewritePass(
bool enableInit) {
325 enableBufferInitialization = enableInit;
328 void runOnOperation()
override {
336 struct SparseVectorizationPass
337 :
public impl::SparseVectorizationBase<SparseVectorizationPass> {
338 SparseVectorizationPass() =
default;
339 SparseVectorizationPass(
const SparseVectorizationPass &pass) =
default;
340 SparseVectorizationPass(
unsigned vl,
bool vla,
bool sidx32) {
342 enableVLAVectorization = vla;
343 enableSIMDIndex32 = sidx32;
346 void runOnOperation()
override {
347 if (vectorLength == 0)
348 return signalPassFailure();
352 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
358 struct SparseGPUCodegenPass
359 :
public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
360 SparseGPUCodegenPass() =
default;
361 SparseGPUCodegenPass(
const SparseGPUCodegenPass &pass) =
default;
362 SparseGPUCodegenPass(
unsigned nT,
bool enableRT) {
364 enableRuntimeLibrary = enableRT;
367 void runOnOperation()
override {
378 struct StorageSpecifierToLLVMPass
379 :
public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
380 StorageSpecifierToLLVMPass() =
default;
382 void runOnOperation()
override {
389 target.addIllegalDialect<SparseTensorDialect>();
390 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
393 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
396 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
397 return converter.
isLegal(op.getOperandTypes());
399 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
401 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
410 std::move(patterns))))
422 return std::make_unique<SparseAssembler>();
426 return std::make_unique<SparseReinterpretMap>();
429 std::unique_ptr<Pass>
431 SparseReinterpretMapOptions
options;
433 return std::make_unique<SparseReinterpretMap>(
options);
437 return std::make_unique<PreSparsificationRewritePass>();
441 return std::make_unique<SparsificationPass>();
444 std::unique_ptr<Pass>
446 return std::make_unique<SparsificationPass>(
options);
450 return std::make_unique<StageSparseOperationsPass>();
454 return std::make_unique<LowerSparseOpsToForeachPass>();
457 std::unique_ptr<Pass>
459 return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
463 return std::make_unique<LowerForeachToSCFPass>();
467 return std::make_unique<LowerSparseIterationToSCFPass>();
471 return std::make_unique<SparseTensorConversionPass>();
475 return std::make_unique<SparseTensorCodegenPass>();
478 std::unique_ptr<Pass>
480 bool enableBufferInitialization) {
481 return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
482 enableBufferInitialization);
486 return std::make_unique<SparseBufferRewritePass>();
489 std::unique_ptr<Pass>
491 return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
495 return std::make_unique<SparseVectorizationPass>();
498 std::unique_ptr<Pass>
500 bool enableVLAVectorization,
501 bool enableSIMDIndex32) {
502 return std::make_unique<SparseVectorizationPass>(
503 vectorLength, enableVLAVectorization, enableSIMDIndex32);
507 return std::make_unique<SparseGPUCodegenPass>();
512 return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
516 return std::make_unique<StorageSpecifierToLLVMPass>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
Sparse tensor type converter into an actual buffer.
Sparse tensor type converter into an opaque pointer.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseVectorizationPass()
std::unique_ptr< Pass > createSparseAssembler()
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createLowerSparseOpsToForeachPass()
std::unique_ptr< Pass > createSparseTensorCodegenPass()
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::unique_ptr< Pass > createSparseGPUCodegenPass()
std::unique_ptr< Pass > createSparseReinterpretMapPass()
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
std::unique_ptr< Pass > createSparseTensorConversionPass()
std::unique_ptr< Pass > createSparseBufferRewritePass()
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
ReinterpretMapScope
Defines a scope for reinterpret map pass.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createStorageSpecifierToLLVMPass()
std::unique_ptr< Pass > createPreSparsificationRewritePass()
std::unique_ptr< Pass > createLowerForeachToSCFPass()
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
std::unique_ptr< Pass > createLowerSparseIterationToSCFPass()
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
LogicalResult applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
std::unique_ptr< Pass > createStageSparseOperationsPass()
std::unique_ptr< Pass > createSparsificationPass()
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Type converter for iter_space and iterator.
Options for the Sparsification pass.