MLIR  16.0.0git
Passes.h
Go to the documentation of this file.
1 //===- Passes.h - MemRef Patterns and Passes --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header declares patterns and passes on MemRef operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
14 #define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
15 
16 #include "mlir/Pass/Pass.h"
17 
18 namespace mlir {
19 
20 class AffineDialect;
21 class ModuleOp;
22 
23 namespace func {
24 class FuncDialect;
25 } // namespace func
26 namespace tensor {
27 class TensorDialect;
28 } // namespace tensor
29 namespace vector {
30 class VectorDialect;
31 } // namespace vector
32 
33 namespace memref {
34 class AllocOp;
35 //===----------------------------------------------------------------------===//
36 // Patterns
37 //===----------------------------------------------------------------------===//
38 
39 /// Collects a set of patterns to rewrite ops within the memref dialect.
41 
42 /// Appends patterns for folding memref.subview ops into consumer load/store ops
43 /// into `patterns`.
45 
46 /// Appends patterns that resolve `memref.dim` operations with values that are
47 /// defined by operations that implement the
48 /// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
49 /// operands.
51  RewritePatternSet &patterns);
52 
53 /// Appends patterns that resolve `memref.dim` operations with values that are
54 /// defined by operations that implement the `InferShapedTypeOpInterface`, in
55 /// terms of shapes of its input operands.
57 
58 /// Transformation to do multi-buffering/array expansion to remove dependencies
59 /// on the temporary allocation between consecutive loop iterations.
60 /// It return success if the allocation was multi-buffered and returns failure()
61 /// otherwise.
62 /// Example:
63 /// ```
64 /// %0 = memref.alloc() : memref<4x128xf32>
65 /// scf.for %iv = %c1 to %c1024 step %c3 {
66 /// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
67 /// "some_use"(%0) : (memref<4x128xf32>) -> ()
68 /// }
69 /// ```
70 /// into:
71 /// ```
72 /// %0 = memref.alloc() : memref<5x4x128xf32>
73 /// scf.for %iv = %c1 to %c1024 step %c3 {
74 /// %s = arith.subi %iv, %c1 : index
75 /// %d = arith.divsi %s, %c3 : index
76 /// %i = arith.remsi %d, %c5 : index
77 /// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
78 /// memref<5x4x128xf32> to memref<4x128xf32, #map0>
79 /// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, #map0>
80 /// "some_use"(%sv) : (memref<4x128xf32, $map0>) -> ()
81 /// }
82 /// ```
83 LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
84 
85 //===----------------------------------------------------------------------===//
86 // Passes
87 //===----------------------------------------------------------------------===//
88 
89 /// Creates an instance of the ExpandOps pass that legalizes memref dialect ops
90 /// to be convertible to LLVM. For example, `memref.reshape` gets converted to
91 /// `memref_reinterpret_cast`.
92 std::unique_ptr<Pass> createExpandOpsPass();
93 
94 /// Creates an operation pass to fold memref.subview ops into consumer
95 /// load/store ops into `patterns`.
96 std::unique_ptr<Pass> createFoldSubViewOpsPass();
97 
98 /// Creates an interprocedural pass to normalize memrefs to have a trivial
99 /// (identity) layout map.
100 std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
101 
102 /// Creates an operation pass to resolve `memref.dim` operations with values
103 /// that are defined by operations that implement the
104 /// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
105 /// operands.
106 std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
107 
108 /// Creates an operation pass to resolve `memref.dim` operations with values
109 /// that are defined by operations that implement the
110 /// `InferShapedTypeOpInterface` or the `ReifyRankedShapeTypeShapeOpInterface`,
111 /// in terms of shapes of its input operands.
112 std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
113 
114 //===----------------------------------------------------------------------===//
115 // Registration
116 //===----------------------------------------------------------------------===//
117 
118 #define GEN_PASS_REGISTRATION
119 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
120 
121 } // namespace memref
122 } // namespace mlir
123 
124 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
Include the generated interface declarations.
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:147
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref.subview ops into consumer load/store ops into patterns...
void populateResolveRankedShapeTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
std::unique_ptr< Pass > createFoldSubViewOpsPass()
Creates an operation pass to fold memref.subview ops into consumer load/store ops.
std::unique_ptr< OperationPass< ModuleOp > > createNormalizeMemRefsPass()
Creates an interprocedural pass to normalize memrefs to have a trivial (identity) layout map...
std::unique_ptr< Pass > createResolveRankedShapeTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Definition: MultiBuffer.cpp:81
std::unique_ptr< Pass > createExpandOpsPass()
Creates an instance of the ExpandOps pass that legalizes memref dialect ops to be convertible to LLVM...
Definition: ExpandOps.cpp:152
std::unique_ptr< Pass > createResolveShapedTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...