MLIR 22.0.0git
VectorUtils.h
Go to the documentation of this file.
1//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
10#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
11
19#include "mlir/Support/LLVM.h"
20
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/TypeSwitch.h"
23
24namespace mlir {
25
26// Forward declarations.
27class AffineMap;
28class Block;
29class Location;
30class OpBuilder;
31class Operation;
32class ShapedType;
33class Value;
34class VectorType;
35class VectorTransferOpInterface;
36
37namespace affine {
38class AffineApplyOp;
39class AffineForOp;
40} // namespace affine
41
42namespace vector {
43/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
44/// the type of `source`.
45Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
46
47/// Returns two dims that are greater than one if the transposition is applied
48/// on a 2D slice. Otherwise, returns a failure.
49FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
50
51/// Return true if `vectorType` is a contiguous slice of `memrefType`,
52/// in the sense that it can be read/written from/to a contiguous area
53/// of the memref.
54///
55/// The leading unit dimensions of the vector type are ignored as they
56/// are not relevant to the result. Let N be the number of the vector
57/// dimensions after ignoring a leading sequence of unit ones.
58///
59/// For `vectorType` to be a contiguous slice of `memrefType`
60/// a) the N trailing dimensions of `memrefType` must be contiguous, and
61/// b) the N-1 trailing dimensions of `vectorType` and `memrefType`
62/// must match.
63///
64/// Examples:
65///
66/// Ex.1 contiguous slice, perfect match
67/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
68/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
69/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
70/// Ex.3 non-contiguous slice, 2 != 3
71/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
72/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
73/// 2 != 3 (allowed)
74/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
75/// Ex.5. contiguous slice, leading two unit dims of the vector ignored,
76/// 2 != 3 (allowed)
77/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
78/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
79/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
80/// Ex.7 contiguous slice, memref needs to be contiguous only in the last
81/// dimension
82/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
83/// Ex.8 non-contiguous slice, memref needs to be contiguous in the last
84/// two dimensions, and it isn't
85/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
86bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
87
88/// Returns an iterator for all positions in the leading dimensions of `vType`
89/// up to the `targetRank`. If any leading dimension before the `targetRank` is
90/// scalable (so cannot be unrolled), it will return an iterator for positions
91/// up to the first scalable dimension.
92///
93/// If no leading dimensions can be unrolled an empty optional will be returned.
94///
95/// Examples:
96///
97/// For vType = vector<2x3x4> and targetRank = 1
98///
99/// The resulting iterator will yield:
100/// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
101///
102/// For vType = vector<3x[4]x5> and targetRank = 0
103///
104/// The scalable dimension blocks unrolling so the iterator yields only:
105/// [0], [1], [2]
106///
107std::optional<StaticTileOffsetRange>
108createUnrollIterator(VectorType vType, int64_t targetRank = 1);
109
110/// Returns a functor (int64_t -> Value) which returns a constant vscale
111/// multiple.
112///
113/// Example:
114/// ```c++
115/// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
116/// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
117/// ```
119 Value vscale = nullptr;
120 return [loc, vscale, &rewriter](int64_t multiplier) mutable {
121 if (!vscale)
122 vscale = vector::VectorScaleOp::create(rewriter, loc);
123 return arith::MulIOp::create(
124 rewriter, loc, vscale,
125 arith::ConstantIndexOp::create(rewriter, loc, multiplier));
126 };
127}
128
129/// Returns a range over the dims (size and scalability) of a VectorType.
130inline auto getDims(VectorType vType) {
131 return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
132}
133
134/// A wrapper for getMixedSizes for vector.transfer_read and
135/// vector.transfer_write Ops (for source and destination, respectively).
136///
137/// Tensor and MemRef types implement their own, very similar version of
138/// getMixedSizes. This method will call the appropriate version (depending on
139/// `hasTensorSemantics`). It will also automatically extract the operand for
140/// which to call it on (source for "read" and destination for "write" ops).
141SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
142 Operation *xfer,
143 RewriterBase &rewriter);
144
145/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
146/// masked (i.e. inside `vector.mask` Op region). In particular:
147/// 1. Matches `SourceOp` operation, Op.
148/// 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
149/// insertion point to avoid inserting new ops into the `vector.mask` Op
150/// region (which only allows one Op).
151/// 2.2 If Op is not masked, this step is skipped.
152/// 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
153/// found in step 2.1.
154///
155/// This wrapper frees patterns from re-implementing the logic to update the
156/// insertion point when a maskable Op is masked. Such patterns are still
157/// responsible for providing an updated ("rewritten") version of:
158/// a. the source Op when mask _is not_ present,
159/// b. the source Op and the masking Op when mask _is_ present.
160/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
161/// the return value will depend on the case above.
162template <class SourceOp>
165
166private:
167 LogicalResult matchAndRewrite(SourceOp sourceOp,
168 PatternRewriter &rewriter) const final {
169 auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
170 if (!maskableOp)
171 return failure();
172
173 Operation *rootOp = sourceOp;
174
175 // If this Op is masked, update the insertion point to avoid inserting into
176 // the vector.mask Op region.
177 OpBuilder::InsertionGuard guard(rewriter);
178 MaskingOpInterface maskOp;
179 if (maskableOp.isMasked()) {
180 maskOp = maskableOp.getMaskingOp();
181 rewriter.setInsertionPoint(maskOp);
182 rootOp = maskOp;
183 }
184
185 FailureOr<Value> newOp =
186 matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
187 if (failed(newOp))
188 return failure();
189
190 // Rewriting succeeded but there are no values to replace.
191 if (rootOp->getNumResults() == 0) {
192 rewriter.eraseOp(rootOp);
193 } else {
194 assert(*newOp != Value() &&
195 "Cannot replace an op's use with an empty value.");
196 rewriter.replaceOp(rootOp, *newOp);
197 }
198 return success();
199 }
200
201public:
202 // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
203 // latter is present, returns a replacement for `maskingOp`. Otherwise,
204 // returns a replacement for `sourceOp`.
205 virtual FailureOr<Value>
206 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
207 PatternRewriter &rewriter) const = 0;
208};
209
210/// Returns true if the input Vector type can be linearized.
211///
212/// Linearization is meant in the sense of flattening vectors, e.g.:
213/// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
214/// In this sense, Vectors that are either:
215/// * already linearized, or
216/// * contain more than 1 scalable dimension,
217/// are not linearizable.
218bool isLinearizableVector(VectorType type);
219
220/// Creates a TransferReadOp from `source`.
221///
222/// If the shape of vector to read differs from the shape of the value being
223/// read, masking is used to avoid out-of-bounds accesses. Set
224/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
225/// instead of explicit masks.
226///
227/// Note: all read offsets are set to 0.
229 const VectorType &vecToReadTy,
230 std::optional<Value> padValue = std::nullopt,
231 bool useInBoundsInsteadOfMasking = false);
232
234 ArrayRef<int64_t> inputVectorSizes,
235 std::optional<Value> padValue = std::nullopt,
236 bool useInBoundsInsteadOfMasking = false,
237 ArrayRef<bool> inputScalableVecDims = {});
238
239/// Returns success if `inputVectorSizes` is a valid masking configuration for
240/// given `shape`, i.e., it meets:
241/// 1. The numbers of elements in both array are equal.
242/// 2. `inputVectorSizes` does not have dynamic dimensions.
243/// 3. All the values in `inputVectorSizes` are greater than or equal to
244/// static sizes in `shape`.
246 ArrayRef<int64_t> inputVectorSizes);
247
248/// Generic utility for unrolling n-D vector operations to (n-1)-D operations.
249/// This handles the common pattern of:
250/// 1. Check if already 1-D. If so, return failure.
251/// 2. Check for scalable dimensions. If so, return failure.
252/// 3. Create poison initialized result.
253/// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to
254/// create sub vectors.
255/// 5. Insert the sub vectors back into the final vector.
256/// 6. Replace the original op with the new result.
259
260LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
261 UnrollVectorOpFn unrollFn);
262
263/// Generic utility for unrolling values of type vector<NxAxBx...>
264/// to N values of type vector<AxBx...> using vector.extract. If the input
265/// is rank-1 or has leading scalable dimension, failure is returned.
266FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType>,
267 RewriterBase &);
268
269} // namespace vector
270
271/// Constructs a permutation map of invariant memref indices to vector
272/// dimension.
273///
274/// If no index is found to be invariant, 0 is added to the permutation_map and
275/// corresponds to a vector broadcast along that dimension.
276///
277/// The implementation uses the knowledge of the mapping of loops to
278/// vector dimension. `loopToVectorDim` carries this information as a map with:
279/// - keys representing "vectorized enclosing loops";
280/// - values representing the corresponding vector dimension.
281/// Note that loopToVectorDim is a whole function map from which only enclosing
282/// loop information is extracted.
283///
284/// Prerequisites: `indices` belong to a vectorizable load or store operation
285/// (i.e. at most one invariant index along each AffineForOp of
286/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
287/// load or store operation.
288///
289/// Example 1:
290/// The following MLIR snippet:
291///
292/// ```mlir
293/// affine.for %i3 = 0 to %0 {
294/// affine.for %i4 = 0 to %1 {
295/// affine.for %i5 = 0 to %2 {
296/// %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
297/// }}}
298/// ```
299///
300/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
301///
302/// ```mlir
303/// affine.for %i3 = 0 to %0 step 32 {
304/// affine.for %i4 = 0 to %1 {
305/// affine.for %i5 = 0 to %2 step 256 {
306/// %4 = vector.transfer_read %arg0, %i4, %i5, %i3
307/// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
308/// (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
309/// }}}
310/// ```
311///
312/// Meaning that vector.transfer_read will be responsible for reading the slice:
313/// `%arg0[%i4, %i5:%i5+256, %i3:%i3+32]` into vector<32x256xf32>.
314///
315/// Example 2:
316/// The following MLIR snippet:
317///
318/// ```mlir
319/// %cst0 = arith.constant 0 : index
320/// affine.for %i0 = 0 to %0 {
321/// %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
322/// }
323/// ```
324///
325/// may vectorize with {permutation_map: (d0) -> (0)} into:
326///
327/// ```mlir
328/// affine.for %i0 = 0 to %0 step 128 {
329/// %3 = vector.transfer_read %arg0, %c0_0, %c0_0
330/// {permutation_map: (d0, d1) -> (0)} :
331/// (memref<?x?xf32>, index, index) -> vector<128xf32>
332/// }
333/// ```
334///
335/// Meaning that vector.transfer_read will be responsible for reading the slice
336/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
337///
340 const DenseMap<Operation *, unsigned> &loopToVectorDim);
343 const DenseMap<Operation *, unsigned> &loopToVectorDim);
344
345namespace matcher {
346
347/// Matches vector.transfer_read, vector.transfer_write and ops that return a
348/// vector type that is a multiple of the sub-vector type. This allows passing
349/// over other smaller vector types in the function and avoids interfering with
350/// operations on those.
351/// This is a first approximation, it can easily be extended in the future.
352/// TODO: this could all be much simpler if we added a bit to a vector type to
353/// mark that a vector is a strict super-vector but it still does not warrant
354/// adding even 1 extra bit in the IR for now.
355bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
356
357} // namespace matcher
358} // namespace mlir
359
360#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static AffineMap makePermutationMap(ArrayRef< Value > indices, const DenseMap< Operation *, unsigned > &enclosingLoopToVectorDim)
Constructs a permutation map from memref indices to vector dimension.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
Block represents an ordered list of Operations.
Definition Block.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType)
Return true if vectorType is a contiguous slice of memrefType, in the sense that it can be read/writt...
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for given shape,...
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, const VectorType &vecToReadTy, std::optional< Value > padValue=std::nullopt, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
function_ref< Value(PatternRewriter &, Location, VectorType, int64_t)> UnrollVectorOpFn
Generic utility for unrolling n-D vector operations to (n-1)-D operations.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0