25 #define GEN_PASS_DEF_SPARSEASSEMBLER
26 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
28 #define GEN_PASS_DEF_SPARSIFICATIONPASS
29 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
30 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
31 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
32 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
33 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
34 #define GEN_PASS_DEF_SPARSEVECTORIZATION
35 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
36 #define GEN_PASS_DEF_STAGESPARSEOPERATIONS
37 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
38 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
50 struct SparseAssembler :
public impl::SparseAssemblerBase<SparseAssembler> {
51 SparseAssembler() =
default;
52 SparseAssembler(
const SparseAssembler &pass) =
default;
53 SparseAssembler(
bool dO) { directOut = dO; }
55 void runOnOperation()
override {
63 struct SparseReinterpretMap
64 :
public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
65 SparseReinterpretMap() =
default;
66 SparseReinterpretMap(
const SparseReinterpretMap &pass) =
default;
67 SparseReinterpretMap(
const SparseReinterpretMapOptions &
options) {
71 void runOnOperation()
override {
79 struct PreSparsificationRewritePass
80 :
public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
81 PreSparsificationRewritePass() =
default;
82 PreSparsificationRewritePass(
const PreSparsificationRewritePass &pass) =
85 void runOnOperation()
override {
93 struct SparsificationPass
94 :
public impl::SparsificationPassBase<SparsificationPass> {
95 SparsificationPass() =
default;
96 SparsificationPass(
const SparsificationPass &pass) =
default;
98 parallelization =
options.parallelizationStrategy;
99 sparseEmitStrategy =
options.sparseEmitStrategy;
100 enableRuntimeLibrary =
options.enableRuntimeLibrary;
103 void runOnOperation()
override {
107 enableRuntimeLibrary);
111 scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
116 struct StageSparseOperationsPass
117 :
public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
118 StageSparseOperationsPass() =
default;
119 StageSparseOperationsPass(
const StageSparseOperationsPass &pass) =
default;
120 void runOnOperation()
override {
128 struct LowerSparseOpsToForeachPass
129 :
public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
130 LowerSparseOpsToForeachPass() =
default;
131 LowerSparseOpsToForeachPass(
const LowerSparseOpsToForeachPass &pass) =
133 LowerSparseOpsToForeachPass(
bool enableRT,
bool convert) {
134 enableRuntimeLibrary = enableRT;
135 enableConvert = convert;
138 void runOnOperation()
override {
147 struct LowerForeachToSCFPass
148 :
public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
149 LowerForeachToSCFPass() =
default;
150 LowerForeachToSCFPass(
const LowerForeachToSCFPass &pass) =
default;
152 void runOnOperation()
override {
160 struct SparseTensorConversionPass
161 :
public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
162 SparseTensorConversionPass() =
default;
163 SparseTensorConversionPass(
const SparseTensorConversionPass &pass) =
default;
165 void runOnOperation()
override {
171 target.addIllegalDialect<SparseTensorDialect>();
175 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
178 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
181 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
184 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
187 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
188 return converter.
isLegal(op.getSource().getType()) &&
189 converter.
isLegal(op.getDest().getType());
191 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
192 [&](tensor::ExpandShapeOp op) {
193 return converter.
isLegal(op.getSrc().getType()) &&
196 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
197 [&](tensor::CollapseShapeOp op) {
198 return converter.
isLegal(op.getSrc().getType()) &&
201 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
202 [&](bufferization::AllocTensorOp op) {
203 return converter.
isLegal(op.getType());
205 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
206 [&](bufferization::DeallocTensorOp op) {
207 return converter.
isLegal(op.getTensor().getType());
211 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
212 linalg::YieldOp, tensor::ExtractOp,
213 tensor::FromElementsOp>();
214 target.addLegalDialect<
215 arith::ArithDialect, bufferization::BufferizationDialect,
216 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
219 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
226 std::move(patterns))))
231 struct SparseTensorCodegenPass
232 :
public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
233 SparseTensorCodegenPass() =
default;
234 SparseTensorCodegenPass(
const SparseTensorCodegenPass &pass) =
default;
235 SparseTensorCodegenPass(
bool createDeallocs,
bool enableInit) {
236 createSparseDeallocs = createDeallocs;
237 enableBufferInitialization = enableInit;
240 void runOnOperation()
override {
246 target.addIllegalDialect<SparseTensorDialect>();
247 target.addLegalOp<SortOp>();
248 target.addLegalOp<PushBackOp>();
250 target.addLegalOp<GetStorageSpecifierOp>();
251 target.addLegalOp<SetStorageSpecifierOp>();
252 target.addLegalOp<StorageSpecifierInitOp>();
254 target.addLegalOp<tensor::FromElementsOp>();
259 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
262 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
265 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
268 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
269 [&](bufferization::AllocTensorOp op) {
270 return converter.
isLegal(op.getType());
272 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
273 [&](bufferization::DeallocTensorOp op) {
274 return converter.
isLegal(op.getTensor().getType());
278 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
279 target.addLegalDialect<
280 arith::ArithDialect, bufferization::BufferizationDialect,
281 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
282 target.addLegalOp<UnrealizedConversionCastOp>();
284 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
289 converter, patterns, createSparseDeallocs, enableBufferInitialization);
291 std::move(patterns))))
296 struct SparseBufferRewritePass
297 :
public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
298 SparseBufferRewritePass() =
default;
299 SparseBufferRewritePass(
const SparseBufferRewritePass &pass) =
default;
300 SparseBufferRewritePass(
bool enableInit) {
301 enableBufferInitialization = enableInit;
304 void runOnOperation()
override {
312 struct SparseVectorizationPass
313 :
public impl::SparseVectorizationBase<SparseVectorizationPass> {
314 SparseVectorizationPass() =
default;
315 SparseVectorizationPass(
const SparseVectorizationPass &pass) =
default;
316 SparseVectorizationPass(
unsigned vl,
bool vla,
bool sidx32) {
318 enableVLAVectorization = vla;
319 enableSIMDIndex32 = sidx32;
322 void runOnOperation()
override {
323 if (vectorLength == 0)
324 return signalPassFailure();
328 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
334 struct SparseGPUCodegenPass
335 :
public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
336 SparseGPUCodegenPass() =
default;
337 SparseGPUCodegenPass(
const SparseGPUCodegenPass &pass) =
default;
338 SparseGPUCodegenPass(
unsigned nT,
bool enableRT) {
340 enableRuntimeLibrary = enableRT;
343 void runOnOperation()
override {
354 struct StorageSpecifierToLLVMPass
355 :
public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
356 StorageSpecifierToLLVMPass() =
default;
358 void runOnOperation()
override {
365 target.addIllegalDialect<SparseTensorDialect>();
366 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
369 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
372 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
375 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
377 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
386 std::move(patterns))))
398 return std::make_unique<SparseAssembler>();
402 return std::make_unique<SparseReinterpretMap>();
405 std::unique_ptr<Pass>
407 SparseReinterpretMapOptions
options;
409 return std::make_unique<SparseReinterpretMap>(
options);
413 return std::make_unique<PreSparsificationRewritePass>();
417 return std::make_unique<SparsificationPass>();
420 std::unique_ptr<Pass>
422 return std::make_unique<SparsificationPass>(
options);
426 return std::make_unique<StageSparseOperationsPass>();
430 return std::make_unique<LowerSparseOpsToForeachPass>();
433 std::unique_ptr<Pass>
435 return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
439 return std::make_unique<LowerForeachToSCFPass>();
443 return std::make_unique<SparseTensorConversionPass>();
447 return std::make_unique<SparseTensorCodegenPass>();
450 std::unique_ptr<Pass>
452 bool enableBufferInitialization) {
453 return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
454 enableBufferInitialization);
458 return std::make_unique<SparseBufferRewritePass>();
461 std::unique_ptr<Pass>
463 return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
467 return std::make_unique<SparseVectorizationPass>();
470 std::unique_ptr<Pass>
472 bool enableVLAVectorization,
473 bool enableSIMDIndex32) {
474 return std::make_unique<SparseVectorizationPass>(
475 vectorLength, enableVLAVectorization, enableSIMDIndex32);
479 return std::make_unique<SparseGPUCodegenPass>();
484 return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
488 return std::make_unique<StorageSpecifierToLLVMPass>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_range getOperandTypes()
Sparse tensor type converter into an actual buffer.
Sparse tensor type converter into an opaque pointer.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Type getType() const
Return the type of this value.
void populateSCFStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseVectorizationPass()
std::unique_ptr< Pass > createSparseAssembler()
std::unique_ptr< Pass > createLowerSparseOpsToForeachPass()
std::unique_ptr< Pass > createSparseTensorCodegenPass()
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
std::unique_ptr< Pass > createSparseGPUCodegenPass()
std::unique_ptr< Pass > createSparseReinterpretMapPass()
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
std::unique_ptr< Pass > createSparseTensorConversionPass()
std::unique_ptr< Pass > createSparseBufferRewritePass()
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter, RewritePatternSet &patterns)
ReinterpretMapScope
Defines a scope for reinterpret map pass.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createStorageSpecifierToLLVMPass()
std::unique_ptr< Pass > createPreSparsificationRewritePass()
std::unique_ptr< Pass > createLowerForeachToSCFPass()
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
void populateSparseTensorConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
std::unique_ptr< Pass > createStageSparseOperationsPass()
void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::unique_ptr< Pass > createSparsificationPass()
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Options for the Sparsification pass.