MLIR  17.0.0git
Passes.h
Go to the documentation of this file.
1 //===- Passes.h - MemRef Patterns and Passes --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header declares patterns and passes on MemRef operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
14 #define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
15 
16 #include "mlir/Pass/Pass.h"
17 
18 namespace mlir {
19 
20 class AffineDialect;
21 class ModuleOp;
22 
23 namespace arith {
24 class WideIntEmulationConverter;
25 } // namespace arith
26 
27 namespace func {
28 class FuncDialect;
29 } // namespace func
30 namespace tensor {
31 class TensorDialect;
32 } // namespace tensor
33 namespace vector {
34 class VectorDialect;
35 } // namespace vector
36 
37 namespace memref {
38 class AllocOp;
39 //===----------------------------------------------------------------------===//
40 // Patterns
41 //===----------------------------------------------------------------------===//
42 
43 /// Collects a set of patterns to rewrite ops within the memref dialect.
45 
46 /// Appends patterns for folding memref aliasing ops into consumer load/store
47 /// ops into `patterns`.
49 
50 /// Appends patterns that resolve `memref.dim` operations with values that are
51 /// defined by operations that implement the
52 /// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
53 /// operands.
55  RewritePatternSet &patterns);
56 
57 /// Appends patterns that resolve `memref.dim` operations with values that are
58 /// defined by operations that implement the `InferShapedTypeOpInterface`, in
59 /// terms of shapes of its input operands.
61 
62 /// Appends patterns for expanding memref operations that modify the metadata
63 /// (sizes, offset, strides) of a memref into easier to analyze constructs.
65 
66 /// Appends patterns for emulating wide integer memref operations with ops over
67 /// narrower integer types.
69  arith::WideIntEmulationConverter &typeConverter,
70  RewritePatternSet &patterns);
71 
72 /// Appends type converions for emulating wide integer memref operations with
73 /// ops over narrowe integer types.
75  arith::WideIntEmulationConverter &typeConverter);
76 
77 /// Transformation to do multi-buffering/array expansion to remove dependencies
78 /// on the temporary allocation between consecutive loop iterations.
79 /// It returns the new allocation if the original allocation was multi-buffered
80 /// and returns failure() otherwise.
81 /// Example:
82 /// ```
83 /// %0 = memref.alloc() : memref<4x128xf32>
84 /// scf.for %iv = %c1 to %c1024 step %c3 {
85 /// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
86 /// "some_use"(%0) : (memref<4x128xf32>) -> ()
87 /// }
88 /// ```
89 /// into:
90 /// ```
91 /// %0 = memref.alloc() : memref<5x4x128xf32>
92 /// scf.for %iv = %c1 to %c1024 step %c3 {
93 /// %s = arith.subi %iv, %c1 : index
94 /// %d = arith.divsi %s, %c3 : index
95 /// %i = arith.remsi %d, %c5 : index
96 /// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
97 /// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
98 /// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
99 /// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
100 /// }
101 /// ```
102 FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
103  unsigned multiplier);
104 
105 //===----------------------------------------------------------------------===//
106 // Passes
107 //===----------------------------------------------------------------------===//
108 
109 #define GEN_PASS_DECL
110 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
111 
112 /// Creates an instance of the ExpandOps pass that legalizes memref dialect ops
113 /// to be convertible to LLVM. For example, `memref.reshape` gets converted to
114 /// `memref_reinterpret_cast`.
115 std::unique_ptr<Pass> createExpandOpsPass();
116 
117 /// Creates an operation pass to fold memref aliasing ops into consumer
118 /// load/store ops into `patterns`.
119 std::unique_ptr<Pass> createFoldMemRefAliasOpsPass();
120 
121 /// Creates an interprocedural pass to normalize memrefs to have a trivial
122 /// (identity) layout map.
123 std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
124 
125 /// Creates an operation pass to resolve `memref.dim` operations with values
126 /// that are defined by operations that implement the
127 /// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
128 /// operands.
129 std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
130 
131 /// Creates an operation pass to resolve `memref.dim` operations with values
132 /// that are defined by operations that implement the
133 /// `InferShapedTypeOpInterface` or the `ReifyRankedShapeTypeShapeOpInterface`,
134 /// in terms of shapes of its input operands.
135 std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
136 
137 /// Creates an operation pass to expand some memref operation into
138 /// easier to reason about operations.
139 std::unique_ptr<Pass> createExpandStridedMetadataPass();
140 
141 //===----------------------------------------------------------------------===//
142 // Registration
143 //===----------------------------------------------------------------------===//
144 
145 #define GEN_PASS_REGISTRATION
146 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
147 
148 } // namespace memref
149 } // namespace mlir
150 
151 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
std::unique_ptr< Pass > createFoldMemRefAliasOpsPass()
Creates an operation pass to fold memref aliasing ops into their consumer load/store ops.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
FailureOr< memref::AllocOp > multiBuffer(memref::AllocOp allocOp, unsigned multiplier)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Definition: MultiBuffer.cpp:80
std::unique_ptr< Pass > createResolveShapedTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveRankedShapeTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
std::unique_ptr< Pass > createExpandStridedMetadataPass()
Creates an operation pass to expand some memref operation into easier to reason about operations.
void populateMemRefWideIntEmulationPatterns(arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating wide integer memref operations with ops over narrower integer types.
void populateMemRefWideIntEmulationConversions(arith::WideIntEmulationConverter &typeConverter)
Appends type conversions for emulating wide integer memref operations with ops over narrower integer types.
std::unique_ptr< Pass > createResolveRankedShapeTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
std::unique_ptr< OperationPass< ModuleOp > > createNormalizeMemRefsPass()
Creates an interprocedural pass to normalize memrefs to have a trivial (identity) layout map.
std::unique_ptr< Pass > createExpandOpsPass()
Creates an instance of the ExpandOps pass that legalizes memref dialect ops to be convertible to LLVM...
Definition: ExpandOps.cpp:158
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:153
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.