MLIR 23.0.0git
Transforms.h
Go to the documentation of this file.
1//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// This header declares functions that assist transformations in the MemRef
10/// dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
15#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
16
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLFunctionalExtras.h"
19
20namespace mlir {
21class OpBuilder;
23class RewriterBase;
24class Value;
25class ValueRange;
26class ReifyRankedShapedTypeOpInterface;
27
28namespace arith {
31} // namespace arith
32
33namespace memref {
34class AllocOp;
35class AllocaOp;
36class DeallocOp;
37
38//===----------------------------------------------------------------------===//
39// Patterns
40//===----------------------------------------------------------------------===//
41
/// Collects a set of patterns that bypass memref.reinterpret_cast Ops. This
/// simplifies the IR in the context of lowering to EmitC.
void populateElideReinterpretCastPatterns(RewritePatternSet &patterns);

/// Collects a set of patterns to rewrite ops within the memref dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);

/// Appends patterns for folding memref aliasing ops into consumer load/store
/// ops into `patterns`.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
52
53/// Appends patterns that resolve `memref.dim` operations with values that are
54/// defined by operations that implement the
55/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
56/// operands.
58 RewritePatternSet &patterns);
59
60/// Appends patterns that resolve `memref.dim` operations with values that are
61/// defined by operations that implement the `InferShapedTypeOpInterface`, in
62/// terms of shapes of its input operands.
63void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
64
/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);

/// Appends patterns for resolving `memref.extract_strided_metadata` into
/// `memref.extract_strided_metadata` of its source.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);

/// Appends patterns for expanding `memref.realloc` operations.
/// NOTE(review): `emitDeallocs` presumably controls whether the expansion
/// emits a dealloc for the original buffer — confirm in the implementation.
void populateExpandReallocPatterns(RewritePatternSet &patterns,
                                   bool emitDeallocs = true);
76
77/// Appends patterns for emulating wide integer memref operations with ops over
78/// narrower integer types.
80 const arith::WideIntEmulationConverter &typeConverter,
81 RewritePatternSet &patterns);
82
83/// Appends type conversions for emulating wide integer memref operations with
84/// ops over narrowe integer types.
86 arith::WideIntEmulationConverter &typeConverter);
87
88/// Appends patterns for emulating memref operations over narrow types with ops
89/// over wider types.
90/// When `disableAtomicRMW` is true, the store patterns generate non-atomic
91/// read-modify-write sequences instead of atomic operations.
92/// When `assumeAligned` is true, `memref.subview` and
93/// `memref.reinterpret_cast` patterns accept dynamic offsets under the
94/// alignment contract that the caller guarantees those offsets are a multiple
95/// of `dstBits / srcBits`. When false (the default), dynamic offsets are
96/// rejected to preserve soundness for callers that cannot prove divisibility.
98 const arith::NarrowTypeEmulationConverter &typeConverter,
99 RewritePatternSet &patterns, bool disableAtomicRMW = false,
100 bool assumeAligned = false);
101
102/// Appends type conversions for emulating memref operations over narrow types
103/// with ops over wider types.
105 arith::NarrowTypeEmulationConverter &typeConverter);
106
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
/// and returns failure() otherwise.
/// When `skipOverrideAnalysis` is true, the pass will apply the transformation
/// without checking that the buffer is overwritten at the beginning of each
/// iteration. This implies that the user knows that there is no data carried
/// across loop iterations. Example:
/// ```
/// %0 = memref.alloc() : memref<4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
///   memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
///   "some_use"(%0) : (memref<4x128xf32>) -> ()
/// }
/// ```
/// into:
/// ```
/// %0 = memref.alloc() : memref<5x4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
///   %s = arith.subi %iv, %c1 : index
///   %d = arith.divsi %s, %c3 : index
///   %i = arith.remsi %d, %c5 : index
///   %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
///     memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
///   memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
///   "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
/// }
/// ```
FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
                                       memref::AllocOp allocOp,
                                       unsigned multiplier,
                                       bool skipOverrideAnalysis = false);
/// Call into `multiBuffer` with locally constructed IRRewriter.
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
                                       unsigned multiplier,
                                       bool skipOverrideAnalysis = false);
143
/// Appends patterns for extracting address computations from the instructions
/// with memory accesses such that these memory accesses use only a base
/// pointer.
///
/// For instance,
/// ```mlir
/// memref.load %base[%off0, ...]
/// ```
///
/// Will be rewritten in:
/// ```mlir
/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
/// memref.load %new_base[%c0,...]
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);

/// Patterns for flattening multi-dimensional memref operations into
/// one-dimensional memref operations.
/// NOTE(review): the split between the three entry points (vector ops on
/// memrefs, memref ops, and the aggregate) is not visible from this header —
/// confirm in the implementation which patterns each one registers.
void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
165
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
///
/// Failure indicates that no suitable upper bound for the dynamic sizes could
/// be found.
FailureOr<Value> buildIndependentOp(OpBuilder &b, AllocaOp allocaOp,
                                    ValueRange independencies);
174
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
///
/// The original AllocaOp is replaced with the new one, wrapped in a SubviewOp.
/// The result type of the replacement is different from the original allocation
/// type: it has the same shape, but a different layout map. This function
/// updates all users that do not have a memref result or memref region block
/// argument, and some frequently used memref dialect ops (such as
/// memref.subview). It does not update other uses such as the init_arg of an
/// scf.for op. Such uses are wrapped in unrealized_conversion_cast.
///
/// Failure indicates that no suitable upper bound for the dynamic sizes could
/// be found.
///
/// Example (make independent of %iv):
/// ```
/// scf.for %iv = %c0 to %sz step %c1 {
///   %0 = memref.alloca(%iv) : memref<?xf32>
///   %1 = memref.subview %0[0][5][1] : ...
///   linalg.generic outs(%1 : ...) ...
///   %2 = scf.for ... iter_arg(%arg0 = %0) ...
///   ...
/// }
/// ```
///
/// The above IR is rewritten to:
///
/// ```
/// scf.for %iv = %c0 to %sz step %c1 {
///   %0 = memref.alloca(%sz - 1) : memref<?xf32>
///   %0_subview = memref.subview %0[0][%iv][1]
///       : memref<?xf32> to memref<?xf32, #map>
///   %1 = memref.subview %0_subview[0][5][1] : ...
///   linalg.generic outs(%1 : ...) ...
///   %cast = unrealized_conversion_cast %0_subview
///       : memref<?xf32, #map> to memref<?xf32>
///   %2 = scf.for ... iter_arg(%arg0 = %cast) ...
///   ...
/// }
/// ```
FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
                                          memref::AllocaOp allocaOp,
                                          ValueRange independencies);
219
/// Replaces the given `alloc` with the corresponding `alloca` and returns it if
/// the following conditions are met:
///   - the corresponding dealloc is available in the same block as the alloc;
///   - the filter, if provided, succeeds on the alloc/dealloc pair.
/// Otherwise returns nullptr and leaves the IR unchanged, so callers must
/// check the returned op before using it.
memref::AllocaOp allocToAlloca(
    RewriterBase &rewriter, memref::AllocOp alloc,
    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
228} // namespace memref
229} // namespace mlir
230
231#endif
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class helps build Operations.
Definition Builders.h:209
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Converts narrow integer or float types that are not supported by the target hardware to wider types.
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateMemRefWideIntEmulationPatterns(const arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating wide integer memref operations with ops over narrower integer types.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void populateMemRefNarrowTypeEmulationPatterns(const arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns, bool disableAtomicRMW=false, bool assumeAligned=false)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
void populateElideReinterpretCastPatterns(RewritePatternSet &patterns)
Collects a set of patterns that bypass memref.reinterpret_cast Ops.
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
void populateMemRefWideIntEmulationConversions(arith::WideIntEmulationConverter &typeConverter)
Appends type conversions for emulating wide integer memref operations with ops over narrower integer t...
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns)
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns)
Patterns for flattening multi-dimensional memref operations into one-dimensional memref operations.
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
FailureOr< Value > buildIndependentOp(OpBuilder &b, AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147