#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
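// Each GEN_PASS_DEF_* macro above asks the generated Passes.h.inc to emit the
// TableGen-derived base class (impl::<PassName>Base<...>) for that pass; the
// pass structs below inherit from those bases to pick up their declared
// options and registration boilerplate.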
struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
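// This pass, like the other rewrite-style passes in this file, only populates
// a RewritePatternSet and hands it to the greedy rewrite driver; the
// conversion-style passes further down instead build a ConversionTarget and
// go through applyPartialConversion.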
struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    gpuDataTransfer = options.gpuDataTransferStrategy;
    enableIndexReduction = options.enableIndexReduction;
    enableGPULibgen = options.enableGPULibgen;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }
  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, gpuDataTransfer,
                                  enableIndexReduction, enableGPULibgen,
                                  enableRuntimeLibrary);
    // Apply GPU libgen (if requested), sparsification, and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    if (enableGPULibgen) {
      // Zero-copy data transfer is not supported with GPU libgen.
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary,
                                      gpuDataTransfer);
    }
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
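// Usage sketch (not from this file; the enum value names below are assumed
// and may differ between revisions). Given a PassManager pm:
//   SparsificationOptions options(
//       SparseParallelizationStrategy::kNone,   // no parallelization
//       GPUDataTransferStrategy::kRegularDMA,   // default transfer mode
//       /*enableIndexReduction=*/false,
//       /*enableGPULibgen=*/false,
//       /*enableRuntimeLibrary=*/true);
//   pm.addPass(createSparsificationPass(options));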
struct PostSparsificationRewritePass
    : public impl::PostSparsificationRewriteBase<
          PostSparsificationRewritePass> {
  PostSparsificationRewritePass() = default;
  PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
      default;
  PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableForeach = foreach;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
                                        enableForeach, enableConvert);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
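// The enableRT / foreach / convert flags are forwarded unchanged to
// populatePostSparsificationRewriting; judging by their names they gate the
// runtime-library dependent rules and the sparse_tensor.foreach and
// sparse_tensor.convert rewrites, respectively (see Passes.h for the
// authoritative semantics).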
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
  SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
    sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
  }
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
    // Translate strategy flags to strategy options.
    SparseTensorConversionOptions options(
        sparseToSparseConversionStrategy(sparseToSparse));
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns, options);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
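// The conversion is partial: ops stay legal only while their types pass the
// converter's isLegal/isSignatureLegal checks, every remaining sparse_tensor
// op is illegal, and any op that cannot be legalized makes
// applyPartialConversion fail, which the pass reports via signalPassFailure().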
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortCooOp>();
    target.addLegalOp<PushBackOp>();
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // bufferization operations as legal output of the rewriting provided
    // that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
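// Conversion vs. codegen: SparseTensorConversionPass uses the type converter
// that maps sparse tensors to opaque runtime-library pointers, while this
// codegen pass uses the converter that maps them to actual buffers, so the
// two passes are alternative lowering strategies rather than a sequence.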
struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
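// Usage sketch (not from this file): the three-argument factory further down
// mirrors this constructor, e.g.
//   pm.addPass(createSparseVectorizationPass(/*vectorLength=*/8,
//                                            /*enableVLAVectorization=*/false,
//                                            /*enableSIMDIndex32=*/false));
// A vectorLength of 0 is rejected in runOnOperation via signalPassFailure().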
struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT) { numThreads = nT; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
  return std::make_unique<PostSparsificationRewritePass>();
}

std::unique_ptr<Pass>
mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
                                          bool enableConvert) {
  return std::make_unique<PostSparsificationRewritePass>(
      enableRT, enableForeach, enableConvert);
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
    const SparseTensorConversionOptions &options) {
  return std::make_unique<SparseTensorConversionPass>(options);
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
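// Illustrative pipeline sketch (assumed ordering, not the canonical sparse
// compiler pipeline; see the SparseTensor Pipelines library for the real one):
//   PassManager pm(ctx);
//   pm.addPass(createPreSparsificationRewritePass());
//   pm.addPass(createSparsificationPass());
//   pm.addPass(createPostSparsificationRewritePass());
//   pm.addPass(createSparseTensorConversionPass());  // or ...CodegenPass()
//   pm.addPass(createStorageSpecifierToLLVMPass());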