//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {

// Forward declarations.
class AffineMap;
class Block;
class Location;
class OpBuilder;
class Operation;
class ShapedType;
class Value;
class VectorType;
class VectorTransferOpInterface;

namespace affine {
class AffineApplyOp;
class AffineForOp;
} // namespace affine

namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
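///
/// Example usage (an illustrative sketch; `builder`, `loc` and `src` are
/// assumed values in the caller, not part of this header):
/// ```c++
/// Value dimSize = createOrFoldDimOp(builder, loc, src, /*dim=*/0);
/// ```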
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

/// Returns the two dims that are greater than one if the transposition is
/// applied on a 2D slice. Otherwise, returns a failure.
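///
/// Example usage (an illustrative sketch; `transposeOp` is an assumed
/// vector::TransposeOp visited by the caller):
/// ```c++
/// FailureOr<std::pair<int, int>> dims = isTranspose2DSlice(transposeOp);
/// if (succeeded(dims)) {
///   // dims->first and dims->second are the two transposed dims > 1.
/// }
/// ```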
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
/// checked (the other dims are not relevant). Note that for `vectorType` to be
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
/// to be contiguous - this is checked by looking at the corresponding strides.
///
/// There might be some restriction on the leading dim of `VectorType`:
///
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
///         of `memrefType` then the leading dim of `vectorType` can be
///         arbitrary.
///
///        Ex. 1.1 contiguous slice, perfect match
///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
/// Case 2. If an "internal" dim of `vectorType` does not match the
///         corresponding trailing dim in `memrefType` then the remaining
///         leading dims of `vectorType` have to be 1 (the first non-matching
///         dim can be arbitrary).
///
///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
///          vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
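///
/// Example usage (an illustrative sketch; the types mirror Ex. 1.2 above and
/// `builder` is an assumed OpBuilder):
/// ```c++
/// auto memTy = MemRefType::get({5, 4, 3, 2}, builder.getI32Type());
/// auto vecTy = VectorType::get({2, 3, 2}, builder.getI32Type());
/// bool contiguous = isContiguousSlice(memTy, vecTy); // true
/// ```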
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);

/// Returns an iterator for all positions in the leading dimensions of `vType`
/// up to the `targetRank`. If any leading dimension before the `targetRank` is
/// scalable (so cannot be unrolled), it will return an iterator for positions
/// up to the first scalable dimension.
///
/// If no leading dimensions can be unrolled an empty optional will be returned.
///
/// Examples:
///
///   For vType = vector<2x3x4> and targetRank = 1
///
///   The resulting iterator will yield:
///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
///
///   For vType = vector<3x[4]x5> and targetRank = 0
///
///   The scalable dimension blocks unrolling so the iterator yields only:
///     [0], [1], [2]
///
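/// Example usage (an illustrative sketch; `builder` is an assumed OpBuilder
/// and the loop body is left empty):
/// ```c++
/// auto vType = VectorType::get({2, 3, 4}, builder.getI32Type());
/// if (std::optional<StaticTileOffsetRange> positions =
///         createUnrollIterator(vType, /*targetRank=*/1)) {
///   for (SmallVector<int64_t> position : *positions) {
///     // `position` visits [0, 0], [0, 1], ..., [1, 2] as listed above.
///   }
/// }
/// ```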
std::optional<StaticTileOffsetRange>
createUnrollIterator(VectorType vType, int64_t targetRank = 1);

/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
/// Tensor and MemRef types implement their own, very similar version of
/// getMixedSizes. This method will call the appropriate version (depending on
/// `hasTensorSemantics`). It will also automatically extract the operand on
/// which to call it (source for "read" and destination for "write" ops).
SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
                                            Operation *xfer,
                                            RewriterBase &rewriter);

/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
/// masked (i.e. inside a `vector.mask` Op region). In particular:
///   1. Matches `SourceOp` operation, Op.
///   2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
///     insertion point to avoid inserting new ops into the `vector.mask` Op
///     region (which only allows one Op).
///   2.2 If Op is not masked, this step is skipped.
///   3. Invokes `matchAndRewriteMaskableOp` on Op and, optionally, maskOp if
///     found in step 2.1.
///
/// This wrapper frees patterns from re-implementing the logic to update the
/// insertion point when a maskable Op is masked. Such patterns are still
/// responsible for providing an updated ("rewritten") version of:
///   a. the source Op when the mask _is not_ present,
///   b. the source Op and the masking Op when the mask _is_ present.
/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
/// the return value will depend on the case above.
template <class SourceOp>
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

private:
  LogicalResult matchAndRewrite(SourceOp sourceOp,
                                PatternRewriter &rewriter) const final {
    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
    if (!maskableOp)
      return failure();

    Operation *rootOp = sourceOp;

    // If this Op is masked, update the insertion point to avoid inserting into
    // the vector.mask Op region.
    OpBuilder::InsertionGuard guard(rewriter);
    MaskingOpInterface maskOp;
    if (maskableOp.isMasked()) {
      maskOp = maskableOp.getMaskingOp();
      rewriter.setInsertionPoint(maskOp);
      rootOp = maskOp;
    }

    FailureOr<Value> newOp =
        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

public:
  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
  // latter is present, returns a replacement for `maskingOp`. Otherwise,
  // returns a replacement for `sourceOp`.
  virtual FailureOr<Value>
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};
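
// Example usage (an illustrative sketch, not part of this header; the pattern
// name and the rewrite logic are hypothetical):
//
//   struct LowerTransposePattern final
//       : MaskableOpRewritePattern<vector::TransposeOp> {
//     using MaskableOpRewritePattern::MaskableOpRewritePattern;
//
//     // Returns the value replacing `sourceOp` when `maskingOp` is null, or
//     // the value replacing `maskingOp` when the source Op is masked.
//     FailureOr<Value>
//     matchAndRewriteMaskableOp(vector::TransposeOp sourceOp,
//                               MaskingOpInterface maskingOp,
//                               PatternRewriter &rewriter) const override {
//       // ...build the replacement op(s) at the current insertion point...
//       return failure();
//     }
//   };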

/// Returns true if the input Vector type can be linearized.
///
/// Linearization is meant in the sense of flattening vectors, e.g.:
///   * vector<NxMxKxi32> -> vector<N*M*Kxi32>
/// In this sense, Vectors that are either:
///   * already linearized, or
///   * contain more than one scalable dimension,
/// are not linearizable.
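///
/// Example (an illustrative sketch; `builder` is an assumed OpBuilder):
/// ```c++
/// // vector<2x3xi32> flattens to vector<6xi32>, so this returns true.
/// bool ok = isLinearizableVector(
///     VectorType::get({2, 3}, builder.getI32Type()));
/// ```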
bool isLinearizableVector(VectorType type);

} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
/// dimension.
///
/// If no index is found to be invariant, 0 is added to the permutation_map and
/// corresponds to a vector broadcast along that dimension.
///
/// The implementation uses the knowledge of the mapping of loops to
/// vector dimension. `loopToVectorDim` carries this information as a map with:
///   - keys representing "vectorized enclosing loops";
///   - values representing the corresponding vector dimension.
/// Note that loopToVectorDim is a whole function map from which only enclosing
/// loop information is extracted.
///
/// Prerequisites: `indices` belong to a vectorizable load or store operation
/// (i.e. at most one invariant index along each AffineForOp of
/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
/// load or store operation.
///
/// Example 1:
/// The following MLIR snippet:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 {
///          %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
///    }}}
/// ```
///
/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 step 32 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 step 256 {
///          %4 = vector.transfer_read %arg0, %i4, %i5, %i3
///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
///               (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
///    }}}
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%i4, %i5:%i5+256, %i3:%i3+32]` into vector<32x256xf32>.
///
/// Example 2:
/// The following MLIR snippet:
///
/// ```mlir
///    %cst0 = arith.constant 0 : index
///    affine.for %i0 = 0 to %0 {
///      %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
///    }
/// ```
///
/// may vectorize with {permutation_map: (d0) -> (0)} into:
///
/// ```mlir
///    affine.for %i0 = 0 to %0 step 128 {
///      %3 = vector.transfer_read %arg0, %c0_0, %c0_0
///           {permutation_map: (d0, d1) -> (0)} :
///           (memref<?x?xf32>, index, index) -> vector<128xf32>
///    }
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
///
AffineMap
makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
AffineMap
makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
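
// Example usage (an illustrative sketch; `loadOp` is an assumed
// affine::AffineLoadOp and `loopToVectorDim` would be gathered by an
// enclosing vectorization pass):
//
//   DenseMap<Operation *, unsigned> loopToVectorDim = /* ... */;
//   auto indices = llvm::to_vector(loadOp.getMapOperands());
//   AffineMap map = makePermutationMap(loadOp.getOperation(), indices,
//                                      loopToVectorDim);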

namespace matcher {

/// Matches vector.transfer_read, vector.transfer_write and ops that return a
/// vector type that is a multiple of the sub-vector type. This allows passing
/// over other smaller vector types in the function and avoids interfering with
/// operations on those.
/// This is a first approximation, it can easily be extended in the future.
/// TODO: this could all be much simpler if we added a bit to vector types to
/// mark that a vector is a strict super-vector, but that still does not
/// warrant adding even 1 extra bit in the IR for now.
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);

} // namespace matcher
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_