MLIR 23.0.0git
Transforms.h
Go to the documentation of this file.
1//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// This header declares functions that assist transformations in the MemRef
10/// dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
15#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
16
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLFunctionalExtras.h"
19
20namespace mlir {
21class OpBuilder;
23class RewriterBase;
24class Value;
25class ValueRange;
26class ReifyRankedShapedTypeOpInterface;
27
28namespace arith {
31} // namespace arith
32
33namespace memref {
34class AllocOp;
35class AllocaOp;
36class DeallocOp;
37
38//===----------------------------------------------------------------------===//
39// Patterns
40//===----------------------------------------------------------------------===//
41
42/// Collects a set of patterns that bypass memref.reinterpet_cast Ops. This
43/// simplifies the IR in the context of lowering to EmitC.
44void populateElideReinterpretCastPatterns(RewritePatternSet &patterns);
45
46/// Collects a set of patterns to rewrite ops within the memref dialect.
47void populateExpandOpsPatterns(RewritePatternSet &patterns);
48
49/// Appends patterns for folding memref aliasing ops into consumer load/store
50/// ops into `patterns`.
51void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
52
53/// Appends patterns that resolve `memref.dim` operations with values that are
54/// defined by operations that implement the
55/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
56/// operands.
58 RewritePatternSet &patterns);
59
60/// Appends patterns that resolve `memref.dim` operations with values that are
61/// defined by operations that implement the `InferShapedTypeOpInterface`, in
62/// terms of shapes of its input operands.
63void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
64
65/// Appends patterns for expanding memref operations that modify the metadata
66/// (sizes, offset, strides) of a memref into easier to analyze constructs.
67void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
68
69/// Appends patterns for resolving `memref.extract_strided_metadata` into
70/// `memref.extract_strided_metadata` of its source.
71void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);
72
73/// Appends patterns for expanding `memref.realloc` operations.
74void populateExpandReallocPatterns(RewritePatternSet &patterns,
75 bool emitDeallocs = true);
76
77/// Appends patterns for emulating wide integer memref operations with ops over
78/// narrower integer types.
80 const arith::WideIntEmulationConverter &typeConverter,
81 RewritePatternSet &patterns);
82
83/// Appends type conversions for emulating wide integer memref operations with
84/// ops over narrowe integer types.
86 arith::WideIntEmulationConverter &typeConverter);
87
88/// Appends patterns for emulating memref operations over narrow types with ops
89/// over wider types.
90/// When `disableAtomicRMW` is true, the store patterns generate non-atomic
91/// read-modify-write sequences instead of atomic operations.
93 const arith::NarrowTypeEmulationConverter &typeConverter,
94 RewritePatternSet &patterns, bool disableAtomicRMW = false);
95
96/// Appends type conversions for emulating memref operations over narrow types
97/// with ops over wider types.
99 arith::NarrowTypeEmulationConverter &typeConverter);
100
101/// Transformation to do multi-buffering/array expansion to remove dependencies
102/// on the temporary allocation between consecutive loop iterations.
103/// It returns the new allocation if the original allocation was multi-buffered
104/// and returns failure() otherwise.
105/// When `skipOverrideAnalysis`, the pass will apply the transformation
106/// without checking thwt the buffer is overrided at the beginning of each
107/// iteration. This implies that user knows that there is no data carried across
108/// loop iterations. Example:
109/// ```
110/// %0 = memref.alloc() : memref<4x128xf32>
111/// scf.for %iv = %c1 to %c1024 step %c3 {
112/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
113/// "some_use"(%0) : (memref<4x128xf32>) -> ()
114/// }
115/// ```
116/// into:
117/// ```
118/// %0 = memref.alloc() : memref<5x4x128xf32>
119/// scf.for %iv = %c1 to %c1024 step %c3 {
120/// %s = arith.subi %iv, %c1 : index
121/// %d = arith.divsi %s, %c3 : index
122/// %i = arith.remsi %d, %c5 : index
123/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
124/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
125/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
126/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
127/// }
128/// ```
129FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
130 memref::AllocOp allocOp,
131 unsigned multiplier,
132 bool skipOverrideAnalysis = false);
133/// Call into `multiBuffer` with locally constructed IRRewriter.
134FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
135 unsigned multiplier,
136 bool skipOverrideAnalysis = false);
137
138/// Appends patterns for extracting address computations from the instructions
139/// with memory accesses such that these memory accesses use only a base
140/// pointer.
141///
142/// For instance,
143/// ```mlir
144/// memref.load %base[%off0, ...]
145/// ```
146///
147/// Will be rewritten in:
148/// ```mlir
149/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
150/// memref.load %new_base[%c0,...]
151/// ```
152void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
153
154/// Patterns for flattening multi-dimensional memref operations into
155/// one-dimensional memref operations.
156void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
157void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
158void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
159
160/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
161/// given independencies. If the op is already independent of all
162/// independencies, the same AllocaOp result is returned.
163///
164/// Failure indicates the no suitable upper bound for the dynamic sizes could be
165/// found.
166FailureOr<Value> buildIndependentOp(OpBuilder &b, AllocaOp allocaOp,
167 ValueRange independencies);
168
169/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
170/// given independencies. If the op is already independent of all
171/// independencies, the same AllocaOp result is returned.
172///
173/// The original AllocaOp is replaced with the new one, wrapped in a SubviewOp.
174/// The result type of the replacement is different from the original allocation
175/// type: it has the same shape, but a different layout map. This function
176/// updates all users that do not have a memref result or memref region block
177/// argument, and some frequently used memref dialect ops (such as
178/// memref.subview). It does not update other uses such as the init_arg of an
179/// scf.for op. Such uses are wrapped in unrealized_conversion_cast.
180///
181/// Failure indicates the no suitable upper bound for the dynamic sizes could be
182/// found.
183///
184/// Example (make independent of %iv):
185/// ```
186/// scf.for %iv = %c0 to %sz step %c1 {
187/// %0 = memref.alloca(%iv) : memref<?xf32>
188/// %1 = memref.subview %0[0][5][1] : ...
189/// linalg.generic outs(%1 : ...) ...
190/// %2 = scf.for ... iter_arg(%arg0 = %0) ...
191/// ...
192/// }
193/// ```
194///
195/// The above IR is rewritten to:
196///
197/// ```
198/// scf.for %iv = %c0 to %sz step %c1 {
199/// %0 = memref.alloca(%sz - 1) : memref<?xf32>
200/// %0_subview = memref.subview %0[0][%iv][1]
201/// : memref<?xf32> to memref<?xf32, #map>
202/// %1 = memref.subview %0_subview[0][5][1] : ...
203/// linalg.generic outs(%1 : ...) ...
204/// %cast = unrealized_conversion_cast %0_subview
205/// : memref<?xf32, #map> to memref<?xf32>
206/// %2 = scf.for ... iter_arg(%arg0 = %cast) ...
207/// ...
208/// }
209/// ```
210FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
211 memref::AllocaOp allocaOp,
212 ValueRange independencies);
213
214/// Replaces the given `alloc` with the corresponding `alloca` and returns it if
215/// the following conditions are met:
216/// - the corresponding dealloc is available in the same block as the alloc;
217/// - the filter, if provided, succeeds on the alloc/dealloc pair.
218/// Otherwise returns nullptr and leaves the IR unchanged.
219memref::AllocaOp allocToAlloca(
220 RewriterBase &rewriter, memref::AllocOp alloc,
221 function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
222} // namespace memref
223} // namespace mlir
224
225#endif
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class helps build Operations.
Definition Builders.h:209
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Converts narrow integer or float types that are not supported by the target hardware to wider types.
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateMemRefWideIntEmulationPatterns(const arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating wide integer memref operations with ops over narrower integer types.
void populateMemRefNarrowTypeEmulationPatterns(const arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns, bool disableAtomicRMW=false)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
void populateElideReinterpretCastPatterns(RewritePatternSet &patterns)
Collects a set of patterns that bypass memref.reinterpret_cast Ops.
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
void populateMemRefWideIntEmulationConversions(arith::WideIntEmulationConverter &typeConverter)
Appends type conversions for emulating wide integer memref operations with ops over narrower integer types.
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns)
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns)
Patterns for flattening multi-dimensional memref operations into one-dimensional memref operations.
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
FailureOr< Value > buildIndependentOp(OpBuilder &b, AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147